kyrylokumar committed
Commit 4d898ee · verified · 1 parent: da989cb

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.

Files changed (50)
  1. .gitattributes +42 -0
  2. __pycache__/main.cpython-310.pyc +0 -0
  3. __pycache__/main1.cpython-310.pyc +0 -0
  4. __pycache__/main2.cpython-310.pyc +0 -0
  5. __pycache__/utils.cpython-310.pyc +0 -0
  6. cnn_dailymail/test.csv +3 -0
  7. cnn_dailymail/train.csv +3 -0
  8. cnn_dailymail/validation.csv +3 -0
  9. last_layer.py +399 -0
  10. main.py +370 -0
  11. main1.py +425 -0
  12. main2.py +511 -0
  13. model0.bin +3 -0
  14. model1.bin +3 -0
  15. model2.bin +3 -0
  16. newspaper-text-summarization-cnn-dailymail.zip +3 -0
  17. utils.py +124 -0
  18. wandb/debug-internal.log +22 -0
  19. wandb/debug.log +27 -0
  20. wandb/run-20241028_085547-mga40p7t/files/code/main.py +124 -0
  21. wandb/run-20241028_085547-mga40p7t/files/config.yaml +69 -0
  22. wandb/run-20241028_085547-mga40p7t/files/output.log +16 -0
  23. wandb/run-20241028_085547-mga40p7t/files/requirements.txt +178 -0
  24. wandb/run-20241028_085547-mga40p7t/files/wandb-metadata.json +60 -0
  25. wandb/run-20241028_085547-mga40p7t/files/wandb-summary.json +1 -0
  26. wandb/run-20241028_085547-mga40p7t/logs/debug-core.log +14 -0
  27. wandb/run-20241028_085547-mga40p7t/logs/debug-internal.log +22 -0
  28. wandb/run-20241028_085547-mga40p7t/logs/debug.log +27 -0
  29. wandb/run-20241028_085547-mga40p7t/run-mga40p7t.wandb +0 -0
  30. wandb/run-20241028_085806-owcrwbil/files/code/main.py +124 -0
  31. wandb/run-20241028_085806-owcrwbil/files/config.yaml +69 -0
  32. wandb/run-20241028_085806-owcrwbil/files/output.log +16 -0
  33. wandb/run-20241028_085806-owcrwbil/files/requirements.txt +178 -0
  34. wandb/run-20241028_085806-owcrwbil/files/wandb-metadata.json +60 -0
  35. wandb/run-20241028_085806-owcrwbil/files/wandb-summary.json +1 -0
  36. wandb/run-20241028_085806-owcrwbil/logs/debug-core.log +13 -0
  37. wandb/run-20241028_085806-owcrwbil/logs/debug-internal.log +22 -0
  38. wandb/run-20241028_085806-owcrwbil/logs/debug.log +27 -0
  39. wandb/run-20241028_085806-owcrwbil/run-owcrwbil.wandb +0 -0
  40. wandb/run-20241028_090044-f9fzz8iy/files/code/main.py +124 -0
  41. wandb/run-20241028_090044-f9fzz8iy/files/config.yaml +51 -0
  42. wandb/run-20241028_090044-f9fzz8iy/files/output.log +15 -0
  43. wandb/run-20241028_090044-f9fzz8iy/files/requirements.txt +178 -0
  44. wandb/run-20241028_090044-f9fzz8iy/files/wandb-metadata.json +60 -0
  45. wandb/run-20241028_090044-f9fzz8iy/files/wandb-summary.json +1 -0
  46. wandb/run-20241028_090044-f9fzz8iy/logs/debug-core.log +13 -0
  47. wandb/run-20241028_090044-f9fzz8iy/logs/debug-internal.log +22 -0
  48. wandb/run-20241028_090044-f9fzz8iy/logs/debug.log +27 -0
  49. wandb/run-20241028_090044-f9fzz8iy/run-f9fzz8iy.wandb +0 -0
  50. wandb/run-20241028_090149-4jbvn26d/files/code/main.py +124 -0
.gitattributes CHANGED
@@ -33,3 +33,45 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ cnn_dailymail/test.csv filter=lfs diff=lfs merge=lfs -text
+ cnn_dailymail/train.csv filter=lfs diff=lfs merge=lfs -text
+ cnn_dailymail/validation.csv filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_090149-4jbvn26d/run-4jbvn26d.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_121715-5kjpxsew/run-5kjpxsew.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_121925-nc6un2i3/run-nc6un2i3.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_123711-98pjnrqo/run-98pjnrqo.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_130604-p34k5wnh/run-p34k5wnh.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_164041-fs40r39w/run-fs40r39w.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_165506-q960l24p/run-q960l24p.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_173226-2kgb0f9e/run-2kgb0f9e.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_183140-jbr9b02q/run-jbr9b02q.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_184435-8g9y4qy1/run-8g9y4qy1.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_191117-1eqkbgu2/run-1eqkbgu2.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_193646-r4qgqj1u/run-r4qgqj1u.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_195038-eakkycgw/run-eakkycgw.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_201220-6xkcnm4u/run-6xkcnm4u.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_202354-zr7kt8eh/run-zr7kt8eh.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_210645-j0itro1g/run-j0itro1g.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_210725-ho9p1f0s/run-ho9p1f0s.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_212229-ovcu6nj3/run-ovcu6nj3.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_212304-uq1xozlm/run-uq1xozlm.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_213435-qow8er7m/run-qow8er7m.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_213754-m1kaqlt1/run-m1kaqlt1.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_221407-dv4g6q0z/run-dv4g6q0z.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_221423-pxyf2xri/run-pxyf2xri.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_222813-k1xslrgl/run-k1xslrgl.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_223604-ucawfmok/run-ucawfmok.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_224822-nlehsykg/run-nlehsykg.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241028_225822-xuv6uhuc/run-xuv6uhuc.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241029_130621-0asef00f/run-0asef00f.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241029_134007-8a5jhu4s/run-8a5jhu4s.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241029_141057-uk43a4xl/run-uk43a4xl.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241029_151508-ya0e0d5g/run-ya0e0d5g.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241029_155952-fb5ojuk9/run-fb5ojuk9.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241029_182402-3dknsv44/run-3dknsv44.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241029_182613-mibkz7zt/run-mibkz7zt.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241029_183624-haor84lw/run-haor84lw.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241029_190201-8uotupup/run-8uotupup.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241029_190305-legb7y4v/run-legb7y4v.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241029_192824-grmmhjzz/run-grmmhjzz.wandb filter=lfs diff=lfs merge=lfs -text
+ wandb/run-20241029_193507-ujpie5pz/run-ujpie5pz.wandb filter=lfs diff=lfs merge=lfs -text
__pycache__/main.cpython-310.pyc ADDED
Binary file (8.91 kB).

__pycache__/main1.cpython-310.pyc ADDED
Binary file (9.74 kB).

__pycache__/main2.cpython-310.pyc ADDED
Binary file (10.8 kB).

__pycache__/utils.cpython-310.pyc ADDED
Binary file (4.03 kB).
cnn_dailymail/test.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:69e091606539b415f768de75dd026bcee37b35b5b50b6088d0a6d0e017559d29
+ size 49890690

cnn_dailymail/train.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dd4ba100d4da4c5fe5414a590a5eca9cb47494e044c179a3aadb96cead676ab7
+ size 1262015264

cnn_dailymail/validation.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:97dedb1d6fd51f94f74e1e9dfd3d7e175fc6b4f4417d9c1cdc33abd0b5fe54a1
+ size 57691847
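
Note: the three CSVs above are committed as Git LFS pointer files, not raw data; the real payloads are fetched with `git lfs pull`. A minimal sketch of reading the pointer fields in Python (the `read_lfs_pointer` helper is illustrative and not part of this repo; the key/value layout follows the `version`/`oid`/`size` lines shown above):

# Illustrative helper (not in this repo): parse a Git LFS pointer file's
# key/value lines, e.g. "oid sha256:..." and "size 49890690".
def read_lfs_pointer(path: str) -> dict:
    fields = {}
    with open(path) as f:
        for line in f:
            if not line.strip():
                continue
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

# Example: read_lfs_pointer("cnn_dailymail/test.csv")["size"] -> "49890690"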
last_layer.py ADDED
@@ -0,0 +1,399 @@
from math import inf
# from utils import *
import torch
import torch.nn as nn
import numpy as np
import torch.utils
import torch.utils.data
# from utils import MyDataset, custom_collate
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
import wandb
import torch.nn.functional as F

import einops
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast

np.random.seed(123)
torch.manual_seed(123)
torch.cuda.random.manual_seed(123)

import lightning as L
import utils
from torchmetrics.text.rouge import ROUGEScore


def top_p_sampling(logits, p=0.9, temperature=0.5):
    # Apply temperature scaling
    logits = logits / temperature

    # Sort logits and get cumulative probabilities
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Mask tokens whose cumulative probability exceeds the threshold
    sorted_indices_to_remove = cumulative_probs > p
    # Shift the mask right so the first token that crosses p is kept
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Scatter the sorted mask back to the original index order
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits[indices_to_remove] = float('-inf')  # Drop masked logits from the distribution

    # Sample from the remaining logits
    probs = F.softmax(logits, dim=-1)
    sampled_indices = torch.multinomial(probs, num_samples=1)
    sampled_indices = sampled_indices.squeeze(1)

    return sampled_indices


class PromptTuningModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.requires_grad_(False)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]

        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))

        # Initialise the soft prompt by tiling the embedding of 'summarise'
        tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
        token_embedding = self.model.transformer.wte(tmp[0])
        self.token_embedding = token_embedding
        for _ in range(num_prompts // 3 - 1):
            self.token_embedding = torch.cat([self.token_embedding, token_embedding])

        data = torch.zeros(num_prompts, 768) + self.token_embedding[:]
        self.learnable_prompt = nn.Parameter(data, requires_grad=True)

    # @torch.compile
    def forward(self, X, y):
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)  # b s d

        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)
        # mask = torch.cat([torch.ones([X.shape[0], self.num_prompts], dtype=torch.long).to(X), torch.where(X != self.pad.to(X.device), 1, 0)], dim=1)
        # labels = torch.where(y == 50257, -100, y)
        # ignore = torch.ones([X.shape[0], self.num_prompts], dtype=torch.long, device=X.device) * -100
        # labels = torch.cat([ignore, labels], dim=1)
        out = self.model(inputs_embeds=embeddings)
        logits = out.logits[:, self.num_prompts:]
        return logits

    def generate_new(self, X):
        batch_size = X.shape[0]
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)
        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(batch_size, 1, 1), embeddings], dim=1)

        cnt = 0
        past_key_values = None
        generated_ids = torch.tensor([], dtype=torch.long, device=X.device).view(batch_size, 0)  # Store all generated tokens

        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            if cnt == 0:
                logits = out.logits[:, self.num_prompts:]
            else:
                logits = out.logits

            logits[:, :, 50257:] = -1e4  # Apply after slicing for correct dimensions

            next_token_ids = top_p_sampling(logits[:, -1, :])  # shape (batch_size,)
            generated_ids = torch.cat([generated_ids, next_token_ids.unsqueeze(-1)], dim=-1)

            # Keep a sequence dimension so inputs_embeds stays (batch, 1, dim)
            embeddings = self.model.transformer.wte(next_token_ids[:, None])

            cnt += 1

            # Stop once every sequence has produced the <|endoftext|> token
            if torch.all((generated_ids == self.eot.item()).any(dim=-1)):
                break

        return generated_ids

    def generate(self, X):
        # Only bs = 1
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)  # b s d
        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)

        cnt = 0
        past_key_values = None
        final_prediction = torch.tensor([], dtype=torch.long).to(X.device)
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            if cnt == 0:
                logits = out.logits[:, self.num_prompts:]
                logits[:, :, 50257:] = -1e4

                output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
                final_prediction = torch.cat([final_prediction, output], dim=1)
                embeddings = self.model.transformer.wte(output)
            else:
                logits = out.logits
                logits[:, :, 50257:] = -1e4

                output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
                final_prediction = torch.cat([final_prediction, output], dim=1)
                embeddings = self.model.transformer.wte(output)

            cnt += 1
            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break

        return final_prediction


class LMModel(nn.Module):
    def __init__(self, num_prompts=0):
        super().__init__()
        self.num_prompts = num_prompts

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.requires_grad_(False)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]

        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))

        self.model.lm_head.requires_grad_(True)

    # @torch.compile
    def forward(self, X, y):
        embeddings = self.model.transformer.wte(X)  # b s d
        logits = self.model(inputs_embeds=embeddings).logits
        return logits

    def generate(self, X):
        # Only bs = 1
        embeddings = self.model.transformer.wte(X)  # b s d

        cnt = 0
        past_key_values = None
        final_prediction = torch.tensor([], dtype=torch.long).to(X.device)
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            if cnt == 0:
                logits = out.logits[:, self.num_prompts:]
                logits[:, :, 50257:] = -1e4

                output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
                final_prediction = torch.cat([final_prediction, output], dim=1)
                embeddings = self.model.transformer.wte(output)
            else:
                logits = out.logits
                logits[:, :, 50257:] = -1e4

                output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
                final_prediction = torch.cat([final_prediction, output], dim=1)
                embeddings = self.model.transformer.wte(output)

            cnt += 1
            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break

        return final_prediction


def zero_after_x(tensor, x):
    """
    Replaces all elements in each row of a 2D tensor after the first occurrence of x with x.

    Args:
        tensor: The input 2D tensor.
        x: The value after which elements are overwritten.

    Returns:
        A new tensor with elements after the first x replaced by x.
    """
    mask = (tensor == x).cumsum(dim=1) > 0  # True at and after the first occurrence of x
    result = tensor.where(~mask, torch.ones_like(tensor, dtype=torch.long) * x)
    return result


class LitModelPromptTuning(L.LightningModule):
    def __init__(self, model, temperature, epoch, lr=1e-4):
        super().__init__()
        self.model = model
        self.lr = lr
        self.model.temperature = temperature
        self.epoch = epoch
        tokenize_to_strings = lambda text: self.model.tokenizer.convert_ids_to_tokens(self.model.tokenizer(text)["input_ids"])
        self.rouge = ROUGEScore(tokenizer=tokenize_to_strings)

        self.save_hyperparameters(ignore=['model'])

    def training_step(self, batch, batch_idx):
        X, y = batch
        logits = self.model(X, y)

        logits[:, :, 50257:] = -1e4
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]), target=y[:, :-1].reshape(-1), ignore_index=50257)

        self.log('Training loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch

        logits = self.model(X, y)
        logits[:, :, 50257:] = -1e4
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]), target=y[:, :-1].reshape(-1), ignore_index=50257)

        self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def on_test_epoch_start(self):
        self.all_text = []
        self.predicted_text = []

    def test_step(self, batch, batch_idx):
        if batch_idx == 0:
            return
        X, y = batch
        out = self.model.generate(X)
        # out = zero_after_x(out, self.model.eot.item())
        pred = self.model.tokenizer.batch_decode(out, skip_special_tokens=True)
        gt = self.model.tokenizer.batch_decode(y, skip_special_tokens=True)

        print(pred)
        print('GAP')
        print(gt)

        for p, g in zip(pred, gt):
            score = self.rouge(p, g)
            print(score)

        self.log_dict(score, on_step=True, on_epoch=True, logger=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer


from lightning.pytorch.loggers import WandbLogger

if __name__ == '__main__':
    torch.set_float32_matmul_precision('medium')
    dl_train, dl_val, dl_test = utils.import_data(bs=25, fraction=0.1)
    # gpt_model = PromptTuningModel(num_prompts=12)
    gpt_model = LMModel(num_prompts=0)

    # gpt_model = torch.compile(gpt_model)
    model = LitModelPromptTuning(
        model=gpt_model,
        lr=1e-4,
        temperature=0.9,
        epoch=10,
    )
    print('Training')

    logger = WandbLogger(project='Anlp-3')
    trainer = L.Trainer(
        accelerator='gpu',
        # limit_train_batches=1,
        # strategy='auto',
        # strategy=pl.strategies.DDPStrategy(find_unused_parameters=True),
        devices=[2],
        default_root_dir='./logs/',  # logs can be visualised with TensorBoard
        num_nodes=1,
        num_sanity_val_steps=1,  # runs a validation step before starting training
        precision='bf16-mixed',  # half precision to reduce memory usage
        max_epochs=5,
        check_val_every_n_epoch=1,  # run validation every epoch
        log_every_n_steps=20,
        logger=logger,
        # detect_anomaly=True,
    )

    # trainer.test(model, dataloaders=dl_test)
    trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
    trainer.test(model, dataloaders=dl_test)
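
A quick sanity check for the `top_p_sampling` helper defined in last_layer.py (a minimal sketch with made-up logits, assuming `from last_layer import top_p_sampling`; note the function masks its input in place, hence the `.clone()`):

import torch

# Two rows of fake vocabulary logits (batch=2, vocab=5); values are illustrative.
logits = torch.tensor([[4.0, 3.0, 0.1, 0.05, 0.01],
                       [0.2, 5.0, 0.1, 0.05, 0.01]])
torch.manual_seed(0)
ids = top_p_sampling(logits.clone(), p=0.9, temperature=0.5)
print(ids.shape)  # torch.Size([2]) -- one sampled token id per row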
main.py ADDED
@@ -0,0 +1,370 @@
from math import inf
# from utils import *
import torch
import torch.nn as nn
import numpy as np
import torch.utils
import torch.utils.data
# from utils import MyDataset, custom_collate
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
import wandb
import torch.nn.functional as F

import einops
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast

np.random.seed(123)
torch.manual_seed(123)
torch.cuda.random.manual_seed(123)

import lightning as L
import utils
from torchmetrics.text.rouge import ROUGEScore


def top_p_sampling(logits, p=0.9, temperature=0.5):
    # Apply temperature scaling
    logits = logits / temperature

    # Sort logits and get cumulative probabilities
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Mask tokens whose cumulative probability exceeds the threshold
    sorted_indices_to_remove = cumulative_probs > p
    # Shift the mask right so the first token that crosses p is kept
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Scatter the sorted mask back to the original index order
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits[indices_to_remove] = float('-inf')  # Drop masked logits from the distribution

    # Sample from the remaining logits
    probs = F.softmax(logits, dim=-1)
    sampled_indices = torch.multinomial(probs, num_samples=1)
    sampled_indices = sampled_indices.squeeze(1)

    return sampled_indices


class PromptTuningModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        # self.model.generation_config.cache_implementation = "static"
        # self.model.generation_config.max_new_tokens = 256
        self.model.requires_grad_(False)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
        self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]

        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))

        # Initialise the soft prompt by tiling the embedding of 'summarise'
        tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
        token_embedding = self.model.transformer.wte(tmp[0])
        self.token_embedding = token_embedding
        for _ in range(num_prompts // 3 - 1):
            self.token_embedding = torch.cat([self.token_embedding, token_embedding])

        data = torch.zeros(num_prompts, 768) + self.token_embedding[:]
        self.learnable_prompt = nn.Parameter(data, requires_grad=True)

        # self.model.transformer.wte.weight[self.start].requires_grad = True

    # @torch.compile
    def forward(self, X, y):
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)  # b s d

        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)
        # mask = torch.cat([torch.ones([X.shape[0], self.num_prompts], dtype=torch.long).to(X), torch.where(X != self.pad.to(X.device), 1, 0)], dim=1)
        # labels = torch.where(y == 50257, -100, y)
        # ignore = torch.ones([X.shape[0], self.num_prompts], dtype=torch.long, device=X.device) * -100
        # labels = torch.cat([ignore, labels], dim=1)
        out = self.model(inputs_embeds=embeddings)
        logits = out.logits[:, self.num_prompts:]
        return logits

    def generate_new(self, X):
        batch_size = X.shape[0]
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)
        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(batch_size, 1, 1), embeddings], dim=1)

        cnt = 0
        past_key_values = None
        generated_ids = torch.tensor([], dtype=torch.long, device=X.device).view(batch_size, 0)  # Store all generated tokens

        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            if cnt == 0:
                logits = out.logits[:, self.num_prompts:]
            else:
                logits = out.logits

            logits[:, :, 50257:] = -1e4  # Apply after slicing for correct dimensions

            next_token_ids = top_p_sampling(logits[:, -1, :])  # shape (batch_size,)
            generated_ids = torch.cat([generated_ids, next_token_ids.unsqueeze(-1)], dim=-1)

            # Keep a sequence dimension so inputs_embeds stays (batch, 1, dim)
            embeddings = self.model.transformer.wte(next_token_ids[:, None])

            cnt += 1

            # Stop once every sequence has produced the <|endoftext|> token
            if torch.all((generated_ids == self.eot.item()).any(dim=-1)):
                break

        return generated_ids

    def generate(self, X):
        # Only bs = 1
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)  # b s d
        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)

        cnt = 0
        past_key_values = None
        final_prediction = torch.tensor([], dtype=torch.long).to(X.device)
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            if cnt == 0:
                logits = out.logits[:, self.num_prompts:]
                logits[:, :, 50257:] = -1e4

                output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
                final_prediction = torch.cat([final_prediction, output], dim=1)
                embeddings = self.model.transformer.wte(output)
            else:
                logits = out.logits
                logits[:, :, 50257:] = -1e4

                output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
                final_prediction = torch.cat([final_prediction, output], dim=1)
                embeddings = self.model.transformer.wte(output)

            cnt += 1
            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break

        return final_prediction


class LMModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.requires_grad_(False)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]

        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))

        self.model.lm_head.requires_grad_(True)

    # @torch.compile
    def forward(self, X, y):
        embeddings = self.model.transformer.wte(X)  # b s d
        logits = self.model(inputs_embeds=embeddings).logits
        return logits


def zero_after_x(arr, x):
    """
    Replaces all elements in each row of a 2D tensor after the first occurrence of x with x.

    Args:
        arr: The input 2D tensor.
        x: The value after which elements are overwritten.

    Returns:
        A new tensor with elements after the first x replaced by x.
    """
    mask = (arr == x).cumsum(dim=1) > 0  # True at and after the first occurrence of x
    result = torch.where(mask, x, arr)
    return result


class LitModelPromptTuning(L.LightningModule):
    def __init__(self, model, temperature, epoch, lr=1e-4):
        super().__init__()
        self.model = model
        self.lr = lr
        self.model.temperature = temperature
        self.epoch = epoch
        self.temperature = temperature

        tokenize_to_strings = lambda text: self.model.tokenizer.convert_ids_to_tokens(self.model.tokenizer(text)["input_ids"])
        self.rouge = ROUGEScore(tokenizer=tokenize_to_strings)

        self.save_hyperparameters(ignore=['model'])

    def training_step(self, batch, batch_idx):
        X, y = batch
        logits = self.model(X, y)

        logits[:, :, 50257:] = -1e4
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]), target=y[:, :-1].reshape(-1), ignore_index=50257)

        self.log('Training loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch

        logits = self.model(X, y)
        logits[:, :, 50257:] = -1e4
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]), target=y[:, :-1].reshape(-1), ignore_index=50257)

        self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def on_test_epoch_start(self):
        self.all_text = []
        self.predicted_text = []

    def test_step(self, batch, batch_idx):
        if batch_idx == 0:
            return
        X, y = batch
        out = self.model.generate(X)
        # out = zero_after_x(out, self.model.eot.to(X.device))
        pred = self.model.tokenizer.batch_decode(out, skip_special_tokens=False)
        gt = self.model.tokenizer.batch_decode(y, skip_special_tokens=False)

        print(pred)
        print('GAP')
        print(gt)

        for p, g in zip(pred, gt):
            score = self.rouge(p, g)
            print(score)

        self.log_dict(score, on_step=True, on_epoch=True, logger=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer


from lightning.pytorch.loggers import WandbLogger

if __name__ == '__main__':

    train = False

    torch.set_float32_matmul_precision('medium')
    dl_train, dl_val, dl_test = utils.import_data(bs=24, fraction=1)
    if train:
        gpt_model = PromptTuningModel(num_prompts=24)
        gpt_model = torch.compile(gpt_model)
    else:
        gpt_model = torch.load('./model0.bin')
    # gpt_model = LMModel(num_prompts=12)

    model = LitModelPromptTuning(
        model=gpt_model,
        lr=1e-3,
        temperature=0.9,
        epoch=5,
    )
    print('Training')

    logger = WandbLogger(project='Anlp-3')
    trainer = L.Trainer(
        accelerator='gpu',
        # limit_train_batches=1,
        # strategy='auto',
        # strategy=pl.strategies.DDPStrategy(find_unused_parameters=True),
        devices=1,
        default_root_dir='./logs/',  # logs can be visualised with TensorBoard
        num_nodes=1,
        num_sanity_val_steps=1,  # runs a validation step before starting training
        precision='bf16-mixed',  # half precision to reduce memory usage
        max_epochs=5,
        check_val_every_n_epoch=1,  # run validation every epoch
        log_every_n_steps=20,
        logger=logger,
        # detect_anomaly=True,
    )

    # trainer.test(model, dataloaders=dl_test)
    if train:
        trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
        trainer.test(model, dataloaders=dl_test)
        torch.save(model.model, './model0.bin')
    else:
        trainer.test(model, dataloaders=dl_test)
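Despite its name, `zero_after_x` in main.py replaces everything after the first occurrence of `x` with `x` itself, which is how a generation would be truncated at GPT-2's end-of-text id (50256). A minimal usage sketch with made-up ids (assuming `from main import zero_after_x`):

import torch

arr = torch.tensor([[7, 3, 50256, 9, 1],
                    [2, 2, 4, 50256, 8]])
# The cumsum-based mask marks every position at or after the first 50256 per row.
print(zero_after_x(arr, 50256))
# tensor([[    7,     3, 50256, 50256, 50256],
#         [    2,     2,     4, 50256, 50256]])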
main1.py ADDED
@@ -0,0 +1,425 @@
from math import inf
# from utils import *
import torch
import torch.nn as nn
import numpy as np
import torch.utils
import torch.utils.data
# from utils import MyDataset, custom_collate
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
import wandb
import torch.nn.functional as F

import einops
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast

np.random.seed(123)
torch.manual_seed(123)
torch.cuda.random.manual_seed(123)

import lightning as L
import utils
from torchmetrics.text.rouge import ROUGEScore


def top_p_sampling(logits, p=0.9, temperature=0.5):
    # Apply temperature scaling
    logits = logits / temperature

    # Sort logits and get cumulative probabilities
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Mask tokens whose cumulative probability exceeds the threshold
    sorted_indices_to_remove = cumulative_probs > p
    # Shift the mask right so the first token that crosses p is kept
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Scatter the sorted mask back to the original index order
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits[indices_to_remove] = float('-inf')  # Drop masked logits from the distribution

    # Sample from the remaining logits
    probs = F.softmax(logits, dim=-1)
    sampled_indices = torch.multinomial(probs, num_samples=1)
    sampled_indices = sampled_indices.squeeze(1)

    return sampled_indices


class PromptTuningModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        # self.model.generation_config.cache_implementation = "static"
        # self.model.generation_config.max_new_tokens = 256
        self.model.requires_grad_(False)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
        self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]

        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))

        # Initialise the soft prompt by tiling the embedding of 'summarise'
        tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
        token_embedding = self.model.transformer.wte(tmp[0])
        self.token_embedding = token_embedding
        for _ in range(num_prompts // 3 - 1):
            self.token_embedding = torch.cat([self.token_embedding, token_embedding])

        data = torch.zeros(num_prompts, 768) + self.token_embedding[:]
        self.learnable_prompt = nn.Parameter(data, requires_grad=True)

        # self.model.transformer.wte.weight[self.start].requires_grad = True

    # @torch.compile
    def forward(self, X, y):
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)  # b s d

        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)
        # mask = torch.cat([torch.ones([X.shape[0], self.num_prompts], dtype=torch.long).to(X), torch.where(X != self.pad.to(X.device), 1, 0)], dim=1)
        # labels = torch.where(y == 50257, -100, y)
        # ignore = torch.ones([X.shape[0], self.num_prompts], dtype=torch.long, device=X.device) * -100
        # labels = torch.cat([ignore, labels], dim=1)
        out = self.model(inputs_embeds=embeddings)
        logits = out.logits[:, self.num_prompts:]
        return logits

    def generate_new(self, X):
        batch_size = X.shape[0]
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)
        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(batch_size, 1, 1), embeddings], dim=1)

        cnt = 0
        past_key_values = None
        generated_ids = torch.tensor([], dtype=torch.long, device=X.device).view(batch_size, 0)  # Store all generated tokens

        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            if cnt == 0:
                logits = out.logits[:, self.num_prompts:]
            else:
                logits = out.logits

            logits[:, :, 50257:] = -1e4  # Apply after slicing for correct dimensions

            next_token_ids = top_p_sampling(logits[:, -1, :])  # shape (batch_size,)
            generated_ids = torch.cat([generated_ids, next_token_ids.unsqueeze(-1)], dim=-1)

            # Keep a sequence dimension so inputs_embeds stays (batch, 1, dim)
            embeddings = self.model.transformer.wte(next_token_ids[:, None])

            cnt += 1

            # Stop once every sequence has produced the <|endoftext|> token
            if torch.all((generated_ids == self.eot.item()).any(dim=-1)):
                break

        return generated_ids

    def generate(self, X):
        # Only bs = 1
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)  # b s d
        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)

        cnt = 0
        past_key_values = None
        final_prediction = torch.tensor([], dtype=torch.long).to(X.device)
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            if cnt == 0:
                logits = out.logits[:, self.num_prompts:]
                logits[:, :, 50257:] = -1e4

                output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
                final_prediction = torch.cat([final_prediction, output], dim=1)
                embeddings = self.model.transformer.wte(output)
            else:
                logits = out.logits
                logits[:, :, 50257:] = -1e4

                output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
                final_prediction = torch.cat([final_prediction, output], dim=1)
                embeddings = self.model.transformer.wte(output)

            cnt += 1
            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break

        return final_prediction


class LMModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        # self.model.generation_config.cache_implementation = "static"
        # self.model.generation_config.max_new_tokens = 256
        self.model.requires_grad_(False)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
        self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]

        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))

        self.model.lm_head.requires_grad_(True)

    # @torch.compile
    def forward(self, X, y):
        embeddings = self.model.transformer.wte(X)  # b s d
        logits = self.model(inputs_embeds=embeddings).logits
        return logits

    def generate(self, X):
        # Only bs = 1
        embeddings = self.model.transformer.wte(X)  # b s d

        cnt = 0
        past_key_values = None
        final_prediction = torch.tensor([], dtype=torch.long).to(X.device)
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            if cnt == 0:
                logits = out.logits[:, self.num_prompts:]
                logits[:, :, 50257:] = -1e4

                output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
                final_prediction = torch.cat([final_prediction, output], dim=1)
                embeddings = self.model.transformer.wte(output)
            else:
                logits = out.logits
                logits[:, :, 50257:] = -1e4

                output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
                final_prediction = torch.cat([final_prediction, output], dim=1)
                embeddings = self.model.transformer.wte(output)

            cnt += 1
            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break

        return final_prediction


def zero_after_x(arr, x):
    """
    Replaces all elements in each row of a 2D tensor after the first occurrence of x with x.

    Args:
        arr: The input 2D tensor.
        x: The value after which elements are overwritten.

    Returns:
        A new tensor with elements after the first x replaced by x.
    """
    mask = (arr == x).cumsum(dim=1) > 0  # True at and after the first occurrence of x
    result = torch.where(mask, x, arr)
    return result


class LitModelPromptTuning(L.LightningModule):
    def __init__(self, model, temperature, epoch, lr=1e-4, **kwargs):
        super().__init__()
        self.model = model
        self.lr = lr
        self.model.temperature = temperature
        self.epoch = epoch
        self.temperature = temperature

        for key, value in kwargs.items():
            setattr(self, key, value)

        tokenize_to_strings = lambda text: self.model.tokenizer.convert_ids_to_tokens(self.model.tokenizer(text)["input_ids"])
        self.rouge = ROUGEScore(tokenizer=tokenize_to_strings)

        self.save_hyperparameters(ignore=['model'])

    def training_step(self, batch, batch_idx):
        X, y = batch
        logits = self.model(X, y)

        logits[:, :, 50257:] = -1e4
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]), target=y[:, :-1].reshape(-1), ignore_index=50257)

        self.log('Training loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch

        logits = self.model(X, y)
        logits[:, :, 50257:] = -1e4
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]), target=y[:, :-1].reshape(-1), ignore_index=50257)

        self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def on_test_epoch_start(self):
        self.all_text = []
        self.predicted_text = []

    def test_step(self, batch, batch_idx):
        if batch_idx == 0:
            return
        X, y = batch
        out = self.model.generate(X)
        # out = zero_after_x(out, self.model.eot.to(X.device))
        pred = self.model.tokenizer.batch_decode(out, skip_special_tokens=False)
        gt = self.model.tokenizer.batch_decode(y, skip_special_tokens=False)

        print(pred)
        print('GAP')
        print(gt)

        for p, g in zip(pred, gt):
            score = self.rouge(p, g)
            print(score)

        self.log_dict(score, on_step=True, on_epoch=True, logger=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer


from lightning.pytorch.loggers import WandbLogger

if __name__ == '__main__':
    train = False

    torch.set_float32_matmul_precision('medium')
    dl_train, dl_val, dl_test = utils.import_data(bs=24, fraction=1)
    # gpt_model = PromptTuningModel(num_prompts=24)
    if train:
        gpt_model = LMModel(num_prompts=12)
        gpt_model = torch.compile(gpt_model)
    else:
        gpt_model = torch.load('./model1.bin')

    model = LitModelPromptTuning(
        model=gpt_model,
        lr=1e-4,
        temperature=0.9,
        epoch=5,
        type_model='lm_head',
    )
    print('Training')

    logger = WandbLogger(project='Anlp-3')
    trainer = L.Trainer(
        accelerator='gpu',
        # limit_train_batches=1,
        # strategy='auto',
        # strategy=pl.strategies.DDPStrategy(find_unused_parameters=True),
        devices=1,
        default_root_dir='./logs/',  # logs can be visualised with TensorBoard
        num_nodes=1,
        num_sanity_val_steps=1,  # runs a validation step before starting training
        precision='bf16-mixed',  # half precision to reduce memory usage
        max_epochs=5,
        check_val_every_n_epoch=1,  # run validation every epoch
        log_every_n_steps=20,
        logger=logger,
        # detect_anomaly=True,
    )

    if train:
        trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
        trainer.test(model, dataloaders=dl_test)
        torch.save(model.model, './model1.bin')
    else:
        trainer.test(model, dataloaders=dl_test)
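
The core of `PromptTuningModel.forward` is just prepending a learned `(num_prompts, d)` matrix to each sequence's token embeddings before the frozen GPT-2 forward pass. A standalone shape sketch with random tensors (dimensions illustrative):

import torch
import torch.nn as nn

num_prompts, d, bs, seq = 6, 768, 2, 10
learnable_prompt = nn.Parameter(torch.randn(num_prompts, d))
token_embeddings = torch.randn(bs, seq, d)  # stands in for model.transformer.wte(X)

# Broadcast the soft prompt over the batch and prepend along the sequence axis.
prompted = torch.cat([learnable_prompt[None, :, :].repeat(bs, 1, 1), token_embeddings], dim=1)
print(prompted.shape)  # torch.Size([2, 16, 768]); logits are later sliced with [:, num_prompts:]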
main2.py ADDED
@@ -0,0 +1,511 @@
1
+ from math import inf
2
+ # from utils import *
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import torch.utils
7
+ import torch.utils.data
8
+ # from utils import MyDataset, custom_collate
9
+ from torch.nn.utils.rnn import pad_sequence,pad_packed_sequence,pack_padded_sequence
10
+ import wandb
11
+ import torch.nn.functional as F
12
+
13
+ import einops
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast
15
+
16
+ np.random.seed(123)
17
+ torch.manual_seed(123)
18
+ torch.cuda.random.manual_seed(123)
19
+
20
+ import lightning as L
21
+ import utils
22
+ from torchmetrics.text.rouge import ROUGEScore
23
+ def top_p_sampling(logits, p=0.9, temperature=0.5):
24
+
25
+ # Apply temperature scaling
26
+ logits = logits / temperature
27
+
28
+ # Sort logits and get cumulative probabilities
29
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
30
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
31
+
32
+ # Create a mask for probabilities above the threshold
33
+ sorted_indices_to_remove = cumulative_probs > p
34
+ # Shift the mask right so the first token above the threshold is also kept
35
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
36
+ sorted_indices_to_remove[..., 0] = 0
37
+
38
+ # Scatter sorted indices to original indices with mask
39
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
40
+ logits[indices_to_remove] = float('-inf') # Set unwanted logits to -inf
41
+
42
+ # Sample from the remaining logits
43
+ probs = F.softmax(logits, dim=-1)
44
+ sampled_indices = torch.multinomial(probs, num_samples=1)
45
+ sampled_indices = sampled_indices.squeeze(1)
46
+
47
+ return sampled_indices
48
+
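+ # Usage sketch (hypothetical tensors): given per-row logits over the resized
+ # vocab, top_p_sampling returns one sampled token id per row:
+ #   next_ids = top_p_sampling(torch.randn(4, 50259), p=0.9, temperature=0.7)  # shape (4,)
+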
49
+ class PromptTuningModel(nn.Module):
50
+ def __init__(self, num_prompts=6):
51
+ super().__init__()
52
+ self.num_prompts = num_prompts
53
+
54
+ self.model = AutoModelForCausalLM.from_pretrained("gpt2", )
55
+ # self.model.generation_config.cache_implementation = "static"
56
+ # self.model.generation_config.max_new_tokens = 256
57
+ self.model.requires_grad_(False)
58
+ self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
59
+ self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
60
+ self.tokenizer.add_special_tokens({'cls_token': '[START]'})
61
+
62
+ self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
63
+ self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
64
+ self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]
65
+
66
+
67
+ self.model.resize_token_embeddings(new_num_tokens = len(self.tokenizer),)
68
+
69
+ tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
70
+ token_embedding = self.model.transformer.wte(tmp[0])
71
+ self.token_embedding = token_embedding
72
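+ # Tile the 'summarise' embedding to fill num_prompts rows; the num_prompts//3
+ # factor presumes the word tokenizes to 3 BPE ids and num_prompts % 3 == 0.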
+ for _ in range(num_prompts//3-1):
73
+ self.token_embedding = torch.cat([self.token_embedding, token_embedding])
74
+
75
+ # print(self.token_embedding.shape)
76
+ data = torch.zeros(num_prompts, 768) + self.token_embedding[:]
77
+ self.learnable_prompt = nn.Parameter(data, requires_grad=True)
78
+
79
+ # self.model.transformer.wte.weight[self.start].requires_grad = True
80
+
81
+ # @torch.compile
82
+ def forward(self, X, y):
83
+ self.learnable_prompt = self.learnable_prompt.to(X.device)
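+ # Under Lightning the module is already on X.device, so .to() returns the
+ # Parameter itself; on a real device move .to() would yield a plain Tensor
+ # and this assignment would raise.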
84
+ embeddings = self.model.transformer.wte(X, ) # b s d
85
+
86
+ embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)
87
+ # mask = torch.cat([torch.ones([X.shape[0],self.num_prompts], dtype=torch.long).to(X), torch.where(X != self.pad.to(X.device), 1, 0)], dim=1)
88
+ # print(mask.shape)
89
+ # labels = torch.where(y == 50257, -100, y)
90
+ # ignore = torch.ones([X.shape[0], self.num_prompts], dtype=torch.long, device=X.device)*-100
91
+ # labels = torch.cat([ignore, labels], dim=1)
92
+ out = self.model(inputs_embeds = embeddings)
93
+ # print("Out.loss:", out.loss)
94
+ logits = out.logits[:,self.num_prompts:]
95
+ return logits
96
+
97
+ def generate_new(self, X):
98
+ batch_size = X.shape[0]
99
+ self.learnable_prompt = self.learnable_prompt.to(X.device)
100
+ embeddings = self.model.transformer.wte(X)
101
+ embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(batch_size, 1, 1), embeddings], dim=1)
102
+
103
+ cnt = 0
104
+ past_key_values = None
105
+ generated_ids = torch.tensor([], dtype=torch.long, device=X.device).view(batch_size, 0) # Store all generated tokens
106
+
107
+ while cnt < 196:
108
+
109
+ out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
110
+ past_key_values = out.past_key_values
111
+ # print(cnt)
112
+ if cnt == 0:
113
+ logits = out.logits[:, self.num_prompts:]
114
+ else:
115
+ logits = out.logits
116
+
117
+ logits[:, :, 50257:] = -1e4 # Apply after slicing for correct dimensions
118
+
119
+ next_token_ids = top_p_sampling(logits[:, -1, :])
120
+ # next_token_ids will have shape (batch_size,)
121
+ # print(next_token_ids.shape)
123
+ generated_ids = torch.cat([generated_ids, next_token_ids.unsqueeze(-1)], dim=-1)
124
+
125
+ embeddings = self.model.transformer.wte(next_token_ids) # Correctly obtains embeddings for current batch
126
+
127
+
128
+ cnt += 1
129
+
130
+ # Check if every sequence has produced the <|endoftext|> token
131
+ if torch.all((generated_ids == self.eot.item()).any(dim=-1)): # Check each sequence independently
132
+ break
133
+
134
+ return generated_ids
135
+ def generate(self, X):
136
+ # Only bs = 1
137
+ self.learnable_prompt = self.learnable_prompt.to(X.device)
138
+ embeddings = self.model.transformer.wte(X, ) # b s d
139
+ embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)
140
+
141
+ cnt = 0
142
+ past_key_values = None
143
+ final_prediction = torch.tensor([], dtype=torch.long).to(X.device)
144
+ while cnt < 196:
145
+ out = self.model(inputs_embeds = embeddings, use_cache=True, past_key_values=past_key_values)
146
+ # print(cnt, out.logits.shape)
147
+ past_key_values = out.past_key_values
148
+ if cnt == 0:
149
+ logits = out.logits[:,self.num_prompts:]
150
+ logits[:,:, 50257:] = -1e4
151
+
152
+ output = top_p_sampling(logits[:,-1,:], temperature=self.temperature)[:,None]
153
+
154
+ # print(output.shape)
155
+
156
+ final_prediction = torch.cat([final_prediction, output], dim=1)
157
+ # print(output.shape)
158
+ embeddings = self.model.transformer.wte(output)
159
+ # print(embeddings.shape)
160
+
161
+
162
+ else:
163
+ # print(logits.shape)
164
+ logits = out.logits
165
+ logits[:, :, 50257:] = -1e4
166
+
167
+ output = top_p_sampling(logits[:,-1,:], temperature=self.temperature)[:,None]
168
+ # print(output)
169
+ final_prediction = torch.cat([final_prediction, output], dim=1)
170
+ # print(final_prediction.shape, 'final')
171
+ embeddings = self.model.transformer.wte(output)
172
+
173
+
174
+
175
+ cnt += 1
176
+ # print(output.shape, self.eot.shape)
177
+ if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
178
+ break
179
+
180
+ return final_prediction
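+
+ # Decode loop above: the prompt is run through the model once, then each
+ # sampled token is fed back with past_key_values so only the newest position
+ # is computed per step; decoding stops after 196 tokens or once every row
+ # has emitted <|endoftext|>.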
181
+ from peft import PeftModel, LoraConfig, get_peft_model
182
+ class LoraModel(nn.Module):
183
+ def __init__(self, dim=8):
184
+ super().__init__()
185
+ self.num_prompts = 0
186
+ self.dim = dim
187
+
188
+ self.model = AutoModelForCausalLM.from_pretrained("gpt2", )
189
+ # self.model.generation_config.cache_implementation = "static"
190
+ # self.model.generation_config.max_new_tokens = 256
191
+ self.model.requires_grad_(False)
192
+ self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
193
+ self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
194
+ self.tokenizer.add_special_tokens({'cls_token': '[START]'})
195
+
196
+ self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
197
+ self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
198
+ self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]
199
+
200
+
201
+ self.model.resize_token_embeddings(new_num_tokens = len(self.tokenizer),)
202
+
203
+ lora_config = LoraConfig(
204
+ r=dim, # Rank of the low-rank matrices (adjust as needed)
205
+ lora_alpha=32, # Scaling factor (adjust as needed)
206
+ target_modules=["c_attn"], # Or other target modules as needed
207
+ lora_dropout=0.05, # Dropout probability for LoRA layers
208
+ bias="none", #Bias type for LoRA. Use "none" for no bias.
209
+ task_type="CAUSAL_LM" #Specify task type.
210
+ )
211
+ self.model = get_peft_model(self.model, lora_config)
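+ # get_peft_model wraps the frozen GPT-2 so that only the rank-r adapters
+ # injected into c_attn (scaled by lora_alpha / r) receive gradients.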
212
+
213
+
214
+
215
+
216
+
217
+ # @torch.compile
218
+ def forward(self, X, y):
219
+ embeddings = self.model.transformer.wte(X, ) # b s d
220
+ logits = self.model(inputs_embeds = embeddings).logits
221
+ return logits
222
+
223
+ def generate(self, X):
224
+ # Only bs = 1
225
+ # self.learnable_prompt = self.learnable_prompt.to(X.device)
226
+ embeddings = self.model.transformer.wte(X, ) # b s d
227
+ # embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)
228
+
229
+ cnt = 0
230
+ past_key_values = None
231
+ final_prediction = torch.tensor([], dtype=torch.long).to(X.device)
232
+ while cnt < 196:
233
+ out = self.model(inputs_embeds = embeddings, use_cache=True, past_key_values=past_key_values)
234
+ # print(cnt, out.logits.shape)
235
+ past_key_values = out.past_key_values
236
+ if cnt == 0:
237
+ logits = out.logits[:,self.num_prompts:]
238
+ logits[:,:, 50257:] = -1e4
239
+
240
+ output = top_p_sampling(logits[:,-1,:], temperature=self.temperature)[:,None]
241
+
242
+ # print(output.shape)
243
+
244
+ final_prediction = torch.cat([final_prediction, output], dim=1)
245
+ # print(output.shape)
246
+ embeddings = self.model.transformer.wte(output)
247
+ # print(embeddings.shape)
248
+
249
+
250
+ else:
251
+ # print(logits.shape)
252
+ logits = out.logits
253
+ logits[:, :, 50257:] = -1e4
254
+
255
+ output = top_p_sampling(logits[:,-1,:], temperature=self.temperature)[:,None]
256
+ # print(output)
257
+ final_prediction = torch.cat([final_prediction, output], dim=1)
258
+ # print(final_prediction.shape, 'final')
259
+ embeddings = self.model.transformer.wte(output)
260
+
261
+
262
+
263
+ cnt += 1
264
+ # print(output.shape, self.eot.shape)
265
+ if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
266
+ break
267
+
268
+ return final_prediction
269
+
270
+
271
+ class LMModel(nn.Module):
272
+ def __init__(self, num_prompts=6):
273
+ super().__init__()
274
+ self.num_prompts = num_prompts
275
+
276
+ self.model = AutoModelForCausalLM.from_pretrained("gpt2", )
277
+ # self.model.generation_config.cache_implementation = "static"
278
+ # self.model.generation_config.max_new_tokens = 256
279
+ self.model.requires_grad_(False)
280
+ self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
281
+ self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
282
+ self.tokenizer.add_special_tokens({'cls_token': '[START]'})
283
+
284
+ self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
285
+ self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
286
+ self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]
287
+
288
+
289
+ self.model.resize_token_embeddings(new_num_tokens = len(self.tokenizer),)
290
+
291
+
292
+ self.model.lm_head.requires_grad_(True)
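+ # NB: GPT-2 ties lm_head.weight to the input embedding (wte) by default, so
+ # unfreezing the head also makes the shared embedding matrix trainable.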
293
+
294
+ # @torch.compile
295
+ def forward(self, X, y):
296
+ embeddings = self.model.transformer.wte(X, ) # b s d
297
+ logits = self.model(inputs_embeds = embeddings).logits
298
+ return logits
299
+
300
+ def generate(self, X):
301
+ # Only bs = 1
302
+ # self.learnable_prompt = self.learnable_prompt.to(X.device)
303
+ embeddings = self.model.transformer.wte(X, ) # b s d
304
+ # embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)
305
+
306
+ cnt = 0
307
+ past_key_values = None
308
+ final_prediction = torch.tensor([], dtype=torch.long).to(X.device)
309
+ while cnt < 196:
310
+ out = self.model(inputs_embeds = embeddings, use_cache=True, past_key_values=past_key_values)
311
+ # print(cnt, out.logits.shape)
312
+ past_key_values = out.past_key_values
313
+ if cnt == 0:
314
+ logits = out.logits[:,self.num_prompts:]
315
+ logits[:,:, 50257:] = -1e4
316
+
317
+ output = top_p_sampling(logits[:,-1,:], temperature=self.temperature)[:,None]
318
+
319
+ # print(output.shape)
320
+
321
+ final_prediction = torch.cat([final_prediction, output], dim=1)
322
+ # print(output.shape)
323
+ embeddings = self.model.transformer.wte(output)
324
+ # print(embeddings.shape)
325
+
326
+
327
+ else:
328
+ # print(logits.shape)
329
+ logits = out.logits
330
+ logits[:, :, 50257:] = -1e4
331
+
332
+ output = top_p_sampling(logits[:,-1,:], temperature=self.temperature)[:,None]
333
+ # print(output)
334
+ final_prediction = torch.cat([final_prediction, output], dim=1)
335
+ # print(final_prediction.shape, 'final')
336
+ embeddings = self.model.transformer.wte(output)
337
+
338
+
339
+
340
+ cnt += 1
341
+ # print(output.shape, self.eot.shape)
342
+ if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
343
+ break
344
+
345
+ return final_prediction
346
+
347
+ def zero_after_x(arr, x):
348
+ """
349
+ Zeros out all elements in each row of a 2D tensor after the first occurrence of x.
350
+
351
+ Args:
352
+ tensor: The input 2D tensor.
353
+ x: The value after which to zero out elements.
354
+
355
+ Returns:
356
+ A new tensor with elements zeroed out after x.
357
+ """
358
+
359
+ mask = (arr == x).cumsum(dim=1) > 0 # Create a cumulative mask
360
+ result = torch.where(mask, x, arr) #zero out where mask is True
361
+
362
+ return result
363
+
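+ # Example: zero_after_x(torch.tensor([[5, 1, 7, 2, 7, 9]]), 7)
+ # -> tensor([[5, 1, 7, 7, 7, 7]]); everything from the first 7 onward becomes 7.
+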
364
+ class LitModelPromptTuning(L.LightningModule):
365
+ def __init__(self, model, temperature, epoch, lr=1e-4, **kwargs):
366
+ super().__init__()
367
+ self.model = model
368
+ self.lr = lr
369
+ self.model.temperature = temperature
370
+ self.epoch = epoch
371
+ self.temperature = temperature
372
+
373
+ for key, value in kwargs.items():
374
+ setattr(self, key, value)
375
+
376
+ tokenize_to_strings = lambda text: self.model.tokenizer.convert_ids_to_tokens(self.model.tokenizer(text)["input_ids"])
377
+ self.rouge = ROUGEScore(tokenizer=tokenize_to_strings)
378
+
379
+ self.save_hyperparameters(ignore=['model'])
380
+
381
+
382
+ def training_step(self, batch, batch_idx):
383
+ X, y = batch
384
+ # for i,j in zip(X[1], y[1]):
385
+ # print(i.item(),j.item())
386
+
387
+ logits = self.model(X, y)
388
+
389
+ logits[:,:, 50257:] = -1e4
390
+ # print(X.shape, y.shape, logits.shape)
391
+ loss = F.cross_entropy(logits[:,:-1,:].reshape(-1, logits.shape[-1]), target=y[:,:-1].reshape(-1), ignore_index=50257)
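+ # CNNDataset already writes next-token targets into y (y[t] = x[t+1] over the
+ # summary span, [PAD]=50257 elsewhere), so logits[t] is scored against y[t]
+ # directly; the [:-1] slice drops the placeholder left in the final slot.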
392
+ # print(loss)
393
+ # prob = logits.softmax(dim=-1)[0,-300:]
394
+ # target = y[0, -300:]
395
+ # print('logits',logits[0,-300:][torch.arange(target.numel()), target])
396
+ # print(prob[torch.arange(target.numel()), target])
397
+ # print(prob.argmax(dim=-1), target, X[0, -300:])
398
+ # print(self.model.tokenizer.decode(prob.argmax(dim=-1)), 'gap', self.model.tokenizer.decode(target))
399
+ # print(self.model.tokenizer.decode(X[0,-300:]))
400
+ # x = F.cross_entropy(logits[0,:-1,:].reshape(-1, logits.shape[-1]), target=y[0,:-1].reshape(-1), ignore_index=50257, reduction='none')
401
+ # print(x[-20:])
402
+ # print(self.model.pad)
403
+
404
+ # print(X[0, -300:].shape, target.shape, prob.argmax(dim=-1).shape)
405
+ # for i,j,k in zip(X[0, -300:], target, prob.argmax(dim=-1)):
406
+ # print(self.model.tokenizer.decode(i),'\tx ',self.model.tokenizer.decode(j),'\tx ',self.model.tokenizer.decode(k))
407
+
408
+
409
+ # exit()
410
+
411
+ self.log('Training loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
412
+ return loss
413
+
414
+
415
+ def validation_step(self, batch, batch_idx):
416
+ X, y = batch
417
+
418
+ logits = self.model(X, y)
419
+ logits[:,:, 50257:] = -1e4
420
+ # print(X.shape, y.shape, logits.shape)
421
+ loss = F.cross_entropy(logits[:,:-1,:].reshape(-1, logits.shape[-1]), target=y[:,:-1].reshape(-1), ignore_index=50257)
422
+
423
+ self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
424
+ return loss
425
+
426
+ def on_test_epoch_start(self, ):
427
+ self.all_text = []
428
+ self.predicted_text = []
429
+
430
+ def test_step(self, batch, batch_idx):
431
+ if batch_idx == 0:
432
+ return
433
+ X, y = batch
434
+ # print(self.model.tokenizer.batch_decode(X))
435
+ # print(X.shape)
436
+ # with torch.no_grad()
437
+ out = self.model.generate(X)
438
+ # out = zero_after_x(out, self.model.eot.to(X.device))
439
+ # print(out.shape, y.shape)
440
+ # print(out, y)
441
+ pred = self.model.tokenizer.batch_decode(out, skip_special_tokens=False)
442
+ gt = self.model.tokenizer.batch_decode(y, skip_special_tokens=False)
443
+
444
+
445
+ print(pred)
446
+ print('GAP')
447
+ print(gt)
448
+ final_score = 0
449
+
450
+ for p,g in zip(pred, gt):
451
+ score = self.rouge(p, g, )
452
+ print(score)
453
+ # exit()
454
+
455
+
456
+ self.log_dict(score, on_step=True, on_epoch=True, logger=True, sync_dist=True)
457
+
458
+
459
+
460
+ def configure_optimizers(self):
461
+ optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
462
+ return optimizer
463
+
464
+ from lightning.pytorch.loggers import WandbLogger
465
+ if __name__ == '__main__':
466
+ train = False
467
+
468
+ torch.set_float32_matmul_precision('medium')
469
+ dl_train, dl_val, dl_test = utils.import_data(bs=24, fraction=1)
470
+ # gpt_model = PromptTuningModel(num_prompts=24)
471
+ if train:
472
+ gpt_model = LoraModel(dim=16)
473
+ else:
474
+ gpt_model = torch.load('./model2.bin') # the LoRA checkpoint written by the training branch below
475
+
476
+ # gpt_model = torch.compile(gpt_model)
477
+ model = LitModelPromptTuning(
478
+ model=gpt_model,
479
+ lr=1e-3,
480
+ temperature=0.9,
481
+ epoch = 5,
482
+
483
+ type_model = 'lora',
484
+ dimension = 16
485
+ )
486
+ print('Training' if train else 'Evaluating')
487
+
488
+ logger = WandbLogger(project='Anlp-3')
489
+ trainer = L.Trainer(
490
+ accelerator='gpu',
491
+ # limit_train_batches=1,
492
+ # strategy='auto',
493
+ # strategy=pl.strategies.DDPStrategy(find_unused_parameters=True),
494
+ devices=1,
495
+ default_root_dir='./logs/', # TensorBoard can be used to visualize these logs
496
+ num_nodes=1,
497
+ num_sanity_val_steps=1, # runs a validation step before starting training
498
+ precision='bf16-mixed', # bfloat16 mixed precision reduces memory usage
499
+ max_epochs=5,
500
+ check_val_every_n_epoch=1, # run validation every epoch
501
+ log_every_n_steps=20,
502
+ logger=logger,
503
+ # detect_anomaly=True,
504
+ )
505
+
506
+ if train:
507
+ trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
508
+ trainer.test(model, dataloaders=dl_test)
509
+ torch.save(model.model, './model2.bin')
510
+ else:
511
+ trainer.test(model, dataloaders=dl_test)
model0.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5411f279c6809268001997e4d5ea65eb050828c703d82a21066eb5f2370f01e3
3
+ size 512094683
model1.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:38409a7f9c3348446f062d5b616f603770976c94e0c789f7baebf28005c87674
3
+ size 511946721
model2.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e74693de89775360e55273ca457bfc44c1396070076530fb272bf5a355aa2a2c
3
+ size 514335129
newspaper-text-summarization-cnn-dailymail.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f9e0cad39333d1c9f8902be6d846000ef1ccd5e994ad3c42f53336270ab8611
3
+ size 527738644
utils.py ADDED
@@ -0,0 +1,124 @@
1
+ from math import inf
2
+ # from utils import *
3
+ from h11 import Data
4
+ import torch
5
+ import torch.nn as nn
6
+ import numpy as np
7
+ import torch.utils
8
+ import torch.utils.data
9
+
10
+ from torch.utils.data import DataLoader, Dataset
11
+ # from utils import MyDataset, custom_collate
12
+ from torch.nn.utils.rnn import pad_sequence,pad_packed_sequence,pack_padded_sequence
13
+ import wandb
14
+ import torch.nn.functional as F
15
+
16
+ import einops
17
+ import pandas as pd
18
+ from transformers import AutoModelForCausalLM, AutoTokenizer
19
+ from transformers import GPT2TokenizerFast
20
+ import os
21
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
22
+ import re
23
+
24
+ class CNNDataset(Dataset):
25
+ def __init__(self, df, max_length = 1000, max_len=21000, test_ds=False):
26
+ super().__init__()
27
+ self.df = df
28
+ self.max_len = max_len
29
+ self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
30
+ self.max_length = max_length
31
+ self.test_ds = test_ds
32
+ self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
33
+ self.tokenizer.add_special_tokens({'cls_token': '[START]'})
34
+
35
+ self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
36
+ self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
37
+ self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]
38
+ print(len(self.tokenizer), self.start, "Pad")
39
+ for index in range(max_len):
40
+ x, y = self.df['article'][index], self.df['highlights'][index]
41
+ x, y = re.sub(r'[\t\n\r]', ' ', x) , re.sub(r'[\t\n\r]', ' ', y)
42
+ y = self.tokenizer(y, return_tensors="pt", max_length=256,truncation=True).input_ids[0]
43
+ x = self.tokenizer(x, return_tensors="pt", max_length=self.max_length-max(y.shape[0], 256+24), truncation=True).input_ids[0]
44
+ self.df.loc[index, 'article'], self.df.loc[index, 'highlights'] = x,y
45
+
46
+ def __len__(self, ):
47
+ return self.max_len
48
+
49
+ def __getitem__(self, index):
50
+ x, y = self.df['article'][index], self.df['highlights'][index]
51
+
52
+ # Check if middle self.eot is needed
53
+ # print(x, self.eot)
54
+ if self.test_ds:
55
+ return torch.cat([self.eot, x, self.start]), torch.cat([y, self.eot])
56
+ x = torch.cat([self.eot, x, self.start, y, self.eot])
57
+ y = torch.cat([y, self.eot])
58
+
59
+ y_final = torch.ones(x.shape[0], dtype=torch.long)
60
+ y_final[-y.shape[0]-1:-1] = y
61
+ y_final[:-y.shape[0]-1] = self.pad
62
+ return x, y_final
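+
+ # y_final[t] holds x[t+1] across the summary span, [PAD] (50257) before it, and
+ # a placeholder 1 in the last slot that the training loss later slices off.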
63
+
64
+ def properly_pad(context):
65
+ lengths = []
66
+ # print(context)
67
+ for i in context:
68
+ lengths.append(i.shape[0])
69
+ lengths = torch.tensor(lengths)
70
+
71
+ ind = torch.argsort(lengths, descending=True)
72
+ lengths = lengths[ind]
73
+
74
+ sorted_tensors = [context[i] for i in ind]
75
+
76
+ context = sorted_tensors
77
+ context = pad_sequence(sequences=context, batch_first=True, padding_value=50257)
78
+
79
+ return context
80
+
81
+ def custom_collate(batch):
82
+ # print(batch)
83
+ context, target = [], []
84
+ # print(batch)
85
+ for a,b in batch:
86
+ context.append(a)
87
+ target.append(b)
88
+
89
+ context, target = properly_pad(context), properly_pad(target)
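+ # NB: properly_pad sorts each list by length independently; (x, y) pairs stay
+ # aligned because train x and y have equal lengths and the test loader uses
+ # batch_size 1.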
90
+
91
+ return context, target
92
+
93
+ def import_data(bs=4, fraction=0.1):
94
+ df_train = pd.read_csv('./cnn_dailymail/train.csv')
95
+ df_val = pd.read_csv('./cnn_dailymail/validation.csv')
96
+ df_test = pd.read_csv('./cnn_dailymail/test.csv')
97
+
98
+ print('Loaded data')
99
+
100
+ df_train, df_val, df_test = CNNDataset(df_train, max_len=int(21000*fraction)), CNNDataset(df_val, max_len=int(fraction*6000)), CNNDataset(df_test, max_len=int(fraction*300), test_ds=True)
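+ # fraction linearly scales the split sizes: 21000 train / 6000 val / 300 test
+ # examples at fraction=1.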
101
+
102
+ df_train = DataLoader(df_train, batch_size=bs, num_workers=7, collate_fn=custom_collate)
103
+ df_test = DataLoader(df_test, batch_size=1, num_workers=7, collate_fn=custom_collate)
104
+ df_val = DataLoader(df_val, batch_size=bs, num_workers=7, collate_fn=custom_collate)
105
+
106
+ # print(df_train['article'][0])
107
+ return df_train, df_val, df_test
108
+
109
+ if __name__ == '__main__':
110
+ tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
111
+
112
+ tokenizer.add_special_tokens({'cls_token': '[START]'})
113
+
114
+ eot =tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
115
+ pad =tokenizer("[PAD]", return_tensors="pt").input_ids[0]
116
+ start =tokenizer("[START]", return_tensors="pt").input_ids[0]
117
+
118
+ print(tokenizer.decode([1, 2, 50256]))
119
+ print(tokenizer.decode([1, 2, 50257]))
120
+ print(tokenizer('[START]'))
121
+ # dl_train, dl_val, dl_test = import_data()
122
+ # for x,y in dl_train:
123
+ # print(x.shape, y.shape)
124
+ # break
wandb/debug-internal.log ADDED
@@ -0,0 +1,22 @@
1
+ {"time":"2024-10-29T19:35:07.322742046+04:00","level":"INFO","msg":"using version","core version":"0.18.1"}
2
+ {"time":"2024-10-29T19:35:07.322754356+04:00","level":"INFO","msg":"created symlink","path":"wandb/run-20241029_193507-ujpie5pz/logs/debug-core.log"}
3
+ {"time":"2024-10-29T19:35:07.323074973+04:00","level":"INFO","msg":"using version","core version":"0.18.1"}
4
+ {"time":"2024-10-29T19:35:07.323078902+04:00","level":"INFO","msg":"created symlink","path":"wandb/run-20241029_193507-ujpie5pz/logs/debug-core.log"}
5
+ {"time":"2024-10-29T19:35:07.327151125+04:00","level":"INFO","msg":"created new stream","id":"ujpie5pz"}
6
+ {"time":"2024-10-29T19:35:07.327170485+04:00","level":"INFO","msg":"stream: started","id":"ujpie5pz"}
7
+ {"time":"2024-10-29T19:35:07.327189255+04:00","level":"INFO","msg":"handler: started","stream_id":{"value":"ujpie5pz"}}
8
+ {"time":"2024-10-29T19:35:07.327206894+04:00","level":"INFO","msg":"sender: started","stream_id":{"value":"ujpie5pz"}}
9
+ {"time":"2024-10-29T19:35:07.327237104+04:00","level":"INFO","msg":"writer: Do: started","stream_id":{"value":"ujpie5pz"}}
10
+ {"time":"2024-10-29T19:35:07.98503491+04:00","level":"INFO","msg":"wandb-core","!BADKEY":null}
11
+ {"time":"2024-10-29T19:35:07.98717852+04:00","level":"INFO","msg":"Starting system monitor"}
12
+ {"time":"2024-10-29T19:35:07.988807925+04:00","level":"ERROR","msg":"git repo not found","error":"repository does not exist"}
13
+ {"time":"2024-10-29T19:37:46.533074708+04:00","level":"INFO","msg":"stream: closing","id":"ujpie5pz"}
14
+ {"time":"2024-10-29T19:37:46.533114998+04:00","level":"INFO","msg":"Stopping system monitor"}
15
+ {"time":"2024-10-29T19:37:46.534237927+04:00","level":"INFO","msg":"Stopped system monitor"}
16
+ {"time":"2024-10-29T19:37:46.799687608+04:00","level":"WARN","msg":"No job ingredients found, not creating job artifact"}
17
+ {"time":"2024-10-29T19:37:46.799694428+04:00","level":"WARN","msg":"No source type found, not creating job artifact"}
18
+ {"time":"2024-10-29T19:37:46.799698199+04:00","level":"INFO","msg":"sender: sendDefer: no job artifact to save"}
19
+ {"time":"2024-10-29T19:37:48.809218785+04:00","level":"INFO","msg":"handler: closed","stream_id":{"value":"ujpie5pz"}}
20
+ {"time":"2024-10-29T19:37:48.809270345+04:00","level":"INFO","msg":"sender: closed","stream_id":{"value":"ujpie5pz"}}
21
+ {"time":"2024-10-29T19:37:48.809270115+04:00","level":"INFO","msg":"writer: Close: closed","stream_id":{"value":"ujpie5pz"}}
22
+ {"time":"2024-10-29T19:37:48.809871399+04:00","level":"INFO","msg":"stream: closed","id":"ujpie5pz"}
wandb/debug.log ADDED
@@ -0,0 +1,27 @@
1
+ 2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_setup.py:_flush():77] Current SDK version is 0.18.1
2
+ 2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_setup.py:_flush():77] Configure stats pid to 1599827
3
+ 2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_setup.py:_flush():77] Loading settings from /home/siddharth.tourani/.config/wandb/settings
4
+ 2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_setup.py:_flush():77] Loading settings from /home/siddharth.tourani/Kyrylo/nlp/wandb/settings
5
+ 2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_setup.py:_flush():77] Loading settings from environment variables: {}
6
+ 2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_setup.py:_flush():77] Applying setup settings: {'mode': None, '_disable_service': None}
7
+ 2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_setup.py:_flush():77] Inferring run settings from compute environment: {'program_relpath': 'main2.py', 'program_abspath': '/home/siddharth.tourani/Kyrylo/nlp/main2.py', 'program': '/home/siddharth.tourani/Kyrylo/nlp/main2.py'}
8
+ 2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_setup.py:_flush():77] Applying login settings: {}
9
+ 2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_init.py:_log_setup():532] Logging user logs to ./wandb/run-20241029_193507-ujpie5pz/logs/debug.log
10
+ 2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_init.py:_log_setup():533] Logging internal logs to ./wandb/run-20241029_193507-ujpie5pz/logs/debug-internal.log
11
+ 2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_init.py:init():616] calling init triggers
12
+ 2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_init.py:init():623] wandb.init called with sweep_config: {}
13
+ config: {}
14
+ 2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_init.py:init():666] starting backend
15
+ 2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_init.py:init():670] setting up manager
16
+ 2024-10-29 19:35:07,319 INFO MainThread:1599827 [backend.py:_multiprocessing_setup():105] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
17
+ 2024-10-29 19:35:07,320 INFO MainThread:1599827 [wandb_init.py:init():678] backend started and connected
18
+ 2024-10-29 19:35:07,322 INFO MainThread:1599827 [wandb_init.py:init():773] updated telemetry
19
+ 2024-10-29 19:35:07,323 INFO MainThread:1599827 [wandb_init.py:init():806] communicating run to backend with 90.0 second timeout
20
+ 2024-10-29 19:35:07,981 INFO MainThread:1599827 [wandb_init.py:init():857] starting run threads in backend
21
+ 2024-10-29 19:35:08,154 INFO MainThread:1599827 [wandb_run.py:_console_start():2459] atexit reg
22
+ 2024-10-29 19:35:08,154 INFO MainThread:1599827 [wandb_run.py:_redirect():2307] redirect: wrap_raw
23
+ 2024-10-29 19:35:08,154 INFO MainThread:1599827 [wandb_run.py:_redirect():2372] Wrapping output streams.
24
+ 2024-10-29 19:35:08,154 INFO MainThread:1599827 [wandb_run.py:_redirect():2397] Redirects installed.
25
+ 2024-10-29 19:35:08,156 INFO MainThread:1599827 [wandb_init.py:init():900] run started, returning control to user process
26
+ 2024-10-29 19:35:08,447 INFO MainThread:1599827 [wandb_run.py:_config_callback():1388] config_cb None None {'temperature': 0.9, 'epoch': 5, 'lr': 0.001, 'type_model': 'lora', 'dimension': 16}
27
+ 2024-10-29 19:37:46,533 WARNING MsgRouterThr:1599827 [router.py:message_loop():77] message_loop has been closed
wandb/run-20241028_085547-mga40p7t/files/code/main.py ADDED
@@ -0,0 +1,124 @@
1
+ from math import inf
2
+ # from utils import *
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import torch.utils
7
+ import torch.utils.data
8
+ # from utils import MyDataset, custom_collate
9
+ from torch.nn.utils.rnn import pad_sequence,pad_packed_sequence,pack_padded_sequence
10
+ import wandb
11
+ import torch.nn.functional as F
12
+
13
+ import einops
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast
15
+
16
+ np.random.seed(123)
17
+ torch.manual_seed(123)
18
+ torch.cuda.random.manual_seed(123)
19
+
20
+ import lightning as L
21
+ import utils
22
+
23
+
24
+ class PromptTuningModel(nn.Module):
25
+ def __init__(self, num_prompts=6):
26
+ super().__init__()
27
+ self.num_prompts = num_prompts
28
+
29
+ self.model = AutoModelForCausalLM.from_pretrained("gpt2", )
30
+ self.model.requires_grad_(False)
31
+ self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
32
+ self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
33
+
34
+ self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
35
+ self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
36
+
37
+ self.model.resize_token_embeddings(new_num_tokens = len(self.tokenizer), pad_to_multiple_of=128)
38
+
39
+ tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
40
+ token_embedding = self.model.transformer.wte(tmp[0])
41
+ self.token_embedding = token_embedding
42
+ for _ in range(num_prompts//3-1):
43
+ self.token_embedding = torch.cat([self.token_embedding, token_embedding])
44
+
45
+ # print(self.token_embedding.shape)
46
+ data = torch.zeros(num_prompts, 768) + self.token_embedding[:]
47
+ self.learnable_prompt = nn.Parameter(data, requires_grad=True)
48
+
49
+ # @torch.compile
50
+ def forward(self, X):
51
+ self.learnable_prompt = self.learnable_prompt.to(X.device)
52
+ embeddings = self.model.transformer.wte(X, ) # b s d
53
+
54
+ embeddings = torch.cat([embeddings, self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1)], dim=1)
55
+ mask = torch.cat([torch.ones([X.shape[0],self.num_prompts], dtype=torch.long).to(X), torch.where(X != self.pad.to(X.device), 1, 0)], dim=1)
56
+ # print(mask.shape)
57
+ logits = self.model(inputs_embeds = embeddings, attention_mask=mask).logits[:,self.num_prompts:].swapaxes(1,2)
58
+ return logits
59
+
60
+ class LitModelPromptTuning(L.LightningModule):
61
+ def __init__(self, model, lr=1e-4):
62
+ super().__init__()
63
+ self.model = model
64
+ self.lr = lr
65
+
66
+ self.save_hyperparameters(ignore=['model'])
67
+
68
+
69
+ def training_step(self, batch, batch_idx):
70
+ X, y = batch
71
+ # for i,j in zip(X[0], y[0]):
72
+ # print(i.item(),j.item())
73
+ # print(self.model.pad, self.model.eot)
74
+ logits = self.model(X)
75
+ loss = F.cross_entropy(logits, target=y, ignore_index=50257)
76
+ # print(loss)
77
+ # exit()
78
+
79
+ self.log('Training loss', loss, on_step=True, on_epoch=True,logger=True, sync_dist=True)
80
+ return loss
81
+
82
+
83
+ def validation_step(self, batch, batch_idx):
84
+ X, y = batch
85
+
86
+ logits = self.model(X)
87
+ loss = F.cross_entropy(logits, target=y, ignore_index=50257)
88
+
89
+ self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
90
+ return loss
91
+
92
+
93
+
94
+ def configure_optimizers(self):
95
+ optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
96
+ return optimizer
97
+
98
+ from lightning.pytorch.loggers import WandbLogger
99
+ if __name__ == '__main__':
100
+ torch.set_float32_matmul_precision('medium')
101
+ dl_train, dl_val, dl_test = utils.import_data(bs=5, fraction=0.1)
102
+ gpt_model = PromptTuningModel()
103
+ # gpt_model = torch.compile(gpt_model)
104
+ model = LitModelPromptTuning(model=gpt_model)
105
+ print('Training')
106
+
107
+ logger = WandbLogger(project='Anlp-3')
108
+ trainer = L.Trainer(
109
+ accelerator='gpu',
110
+ # strategy='auto',
111
+ # strategy=pl.strategies.DDPStrategy(find_unused_parameters=True),
112
+ devices=[3],
113
+ default_root_dir=f'./logs/', # Tensorflow can be used to viz
114
+ num_nodes=1,
115
+ num_sanity_val_steps=1, # runs a validation step before stating training
116
+ precision='16-mixed', # we use half precision to reduce memory usage
117
+ max_epochs=10,
118
+ check_val_every_n_epoch=1, # run validation every epoch
119
+ log_every_n_steps=20,
120
+ logger=logger
121
+ # detect_anomaly=True,
122
+ )
123
+
124
+ trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
wandb/run-20241028_085547-mga40p7t/files/config.yaml ADDED
@@ -0,0 +1,69 @@
1
+ _wandb:
2
+ value:
3
+ cli_version: 0.18.1
4
+ code_path: code/main.py
5
+ m:
6
+ - "1": trainer/global_step
7
+ "6":
8
+ - 3
9
+ "7": []
10
+ - "1": Training loss_step
11
+ "5": 1
12
+ "6":
13
+ - 1
14
+ - 3
15
+ "7": []
16
+ - "1": epoch
17
+ "5": 1
18
+ "6":
19
+ - 1
20
+ - 3
21
+ "7": []
22
+ - "1": Validation loss_step
23
+ "5": 1
24
+ "6":
25
+ - 1
26
+ - 3
27
+ "7": []
28
+ - "1": Validation loss_epoch
29
+ "5": 1
30
+ "6":
31
+ - 1
32
+ - 3
33
+ "7": []
34
+ - "1": Training loss_epoch
35
+ "5": 1
36
+ "6":
37
+ - 1
38
+ - 3
39
+ "7": []
40
+ python_version: 3.10.13
41
+ t:
42
+ "1":
43
+ - 1
44
+ - 11
45
+ - 49
46
+ - 55
47
+ - 71
48
+ - 106
49
+ "2":
50
+ - 1
51
+ - 11
52
+ - 49
53
+ - 55
54
+ - 71
55
+ - 106
56
+ "3":
57
+ - 7
58
+ - 23
59
+ - 55
60
+ - 66
61
+ "4": 3.10.13
62
+ "5": 0.18.1
63
+ "6": 4.44.2
64
+ "8":
65
+ - 5
66
+ "12": 0.18.1
67
+ "13": linux-x86_64
68
+ lr:
69
+ value: 0.0001
wandb/run-20241028_085547-mga40p7t/files/output.log ADDED
@@ -0,0 +1,16 @@
1
+ LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
2
+
3
+ | Name | Type | Params | Mode
4
+ ----------------------------------------------------
5
+ 0 | model | PromptTuningModel | 124 M | train
6
+ ----------------------------------------------------
7
+ 4.6 K Trainable params
8
+ 124 M Non-trainable params
9
+ 124 M Total params
10
+ 497.922 Total estimated model params size (MB)
11
+ 1 Modules in train mode
12
+ 164 Modules in eval mode
13
+ Epoch 1: 3%|██▌ | 11/420 [00:02<01:47, 3.79it/s, v_num=0p7t]
14
+
15
+
16
+ Detected KeyboardInterrupt, attempting graceful shutdown ...
wandb/run-20241028_085547-mga40p7t/files/requirements.txt ADDED
@@ -0,0 +1,178 @@
1
+ jiter==0.5.0
2
+ anyio==4.6.0
3
+ interegular==0.3.3
4
+ jaxlib==0.4.34
5
+ jsonschema==4.23.0
6
+ typing_extensions==4.12.2
7
+ httpcore==1.0.5
8
+ prometheus_client==0.21.0
9
+ openai==1.51.0
10
+ multidict==6.1.0
11
+ six==1.16.0
12
+ nvidia-nccl-cu12==2.20.5
13
+ nvidia-cuda-runtime-cu12==12.1.105
14
+ nvidia-cuda-cupti-cu12==12.1.105
15
+ nvidia-cudnn-cu12==9.1.0.70
16
+ watchfiles==0.24.0
17
+ tqdm==4.66.5
18
+ yarl==1.11.1
19
+ cffi==1.17.1
20
+ vllm==0.6.1.post2
21
+ bleach==6.1.0
22
+ kaggle==1.6.17
23
+ pydantic_core==2.23.4
24
+ lightning-utilities==0.11.7
25
+ sentry-sdk==2.14.0
26
+ torch==2.4.0
27
+ aiohappyeyeballs==2.4.0
28
+ diffusers==0.15.0
29
+ GitPython==3.1.43
30
+ attrs==24.2.0
31
+ importlib_metadata==8.5.0
32
+ transformers==4.44.2
33
+ pillow==10.4.0
34
+ sounddevice==0.5.1
35
+ gguf==0.9.1
36
+ python-dotenv==1.0.1
37
+ async-timeout==4.0.3
38
+ dspy-ai==2.5.3
39
+ numpy==1.26.4
40
+ nvidia-nvjitlink-cu12==12.6.68
41
+ uvicorn==0.30.6
42
+ kiwisolver==1.4.7
43
+ partial-json-parser==0.2.1.1.post4
44
+ pyparsing==3.1.4
45
+ lightning==2.4.0
46
+ structlog==24.4.0
47
+ nvidia-curand-cu12==10.3.2.106
48
+ setuptools==65.5.0
49
+ webencodings==0.5.1
50
+ nvidia-nvtx-cu12==12.1.105
51
+ sniffio==1.3.1
52
+ MarkupSafe==2.1.5
53
+ vllm-flash-attn==2.6.1
54
+ urllib3==2.2.3
55
+ requests==2.32.3
56
+ pycountry==24.6.1
57
+ ujson==5.10.0
58
+ matplotlib==3.9.2
59
+ pydantic==2.9.2
60
+ torchvision==0.19.0
61
+ numba==0.60.0
62
+ optuna==4.0.0
63
+ opt_einsum==3.4.0
64
+ joblib==1.4.2
65
+ msgpack==1.1.0
66
+ smmap==5.0.1
67
+ filelock==3.16.1
68
+ opencv-contrib-python==4.10.0.84
69
+ faiss-gpu==1.7.2
70
+ prometheus-fastapi-instrumentator==7.0.0
71
+ rpds-py==0.20.0
72
+ psutil==6.0.0
73
+ colorlog==6.8.2
74
+ nvidia-cufft-cu12==11.0.2.54
75
+ SQLAlchemy==2.0.35
76
+ llvmlite==0.43.0
77
+ packaging==24.1
78
+ exceptiongroup==1.2.2
79
+ dill==0.3.8
80
+ ml_dtypes==0.5.0
81
+ pyairports==2.1.1
82
+ scikit-learn==1.5.2
83
+ prettytable==3.11.0
84
+ protobuf==4.25.5
85
+ charset-normalizer==3.3.2
86
+ torchmetrics==1.4.2
87
+ text-unidecode==1.3
88
+ httpx==0.27.2
89
+ sympy==1.13.3
90
+ msgspec==0.18.6
91
+ wandb==0.18.1
92
+ backoff==2.2.1
93
+ sentencepiece==0.2.0
94
+ aiohttp==3.10.5
95
+ distro==1.9.0
96
+ lark==1.2.2
97
+ pyarrow==17.0.0
98
+ Mako==1.3.5
99
+ regex==2024.9.11
100
+ safetensors==0.4.5
101
+ aiosignal==1.3.1
102
+ jsonschema-specifications==2023.12.1
103
+ cloudpickle==3.0.0
104
+ einops==0.8.0
105
+ ray==2.36.1
106
+ fire==0.7.0
107
+ pyzmq==26.2.0
108
+ pycparser==2.22
109
+ platformdirs==4.3.6
110
+ click==8.1.7
111
+ fastapi==0.115.0
112
+ ftfy==6.3.0
113
+ torchtext==0.18.0
114
+ lm-format-enforcer==0.10.6
115
+ fsspec==2024.6.1
116
+ tzdata==2024.2
117
+ starlette==0.38.6
118
+ cycler==0.12.1
119
+ py-cpuinfo==9.0.0
120
+ h11==0.14.0
121
+ huggingface-hub==0.25.1
122
+ nvidia-cusparse-cu12==12.1.0.106
123
+ nvidia-ml-py==12.560.30
124
+ certifi==2024.8.30
125
+ httptools==0.6.1
126
+ jax==0.4.34
127
+ PyYAML==6.0.2
128
+ xxhash==3.5.0
129
+ idna==3.10
130
+ xformers==0.0.27.post2
131
+ mistral_common==1.4.3
132
+ fonttools==4.54.0
133
+ pip==23.0.1
134
+ accelerate==0.34.2
135
+ mediapipe==0.10.15
136
+ pytorch-lightning==2.4.0
137
+ ollama==0.3.3
138
+ Jinja2==3.1.4
139
+ multiprocess==0.70.16
140
+ opencv-python==4.10.0.84
141
+ termcolor==2.5.0
142
+ python-dateutil==2.9.0.post0
143
+ contourpy==1.3.0
144
+ websockets==13.1
145
+ frozenlist==1.4.1
146
+ pandas==2.2.3
147
+ networkx==3.3
148
+ diskcache==5.6.3
149
+ nvidia-cusolver-cu12==11.4.5.107
150
+ flatbuffers==24.3.25
151
+ mpmath==1.3.0
152
+ setproctitle==1.3.3
153
+ tokenizers==0.19.1
154
+ scipy==1.14.1
155
+ outlines==0.0.46
156
+ annotated-types==0.7.0
157
+ docker-pycreds==0.4.0
158
+ magicattr==0.1.6
159
+ wcwidth==0.2.13
160
+ pytorch-metric-learning==2.6.0
161
+ datasets==3.0.0
162
+ gitdb==4.0.11
163
+ lora-diffusion==0.1.7
164
+ referencing==0.35.1
165
+ python-slugify==8.0.4
166
+ zipp==3.20.2
167
+ triton==3.0.0
168
+ absl-py==2.1.0
169
+ threadpoolctl==3.5.0
170
+ uvloop==0.20.0
171
+ tiktoken==0.7.0
172
+ pytz==2024.2
173
+ nest-asyncio==1.6.0
174
+ nvidia-cublas-cu12==12.1.3.1
175
+ litellm==1.48.12
176
+ nvidia-cuda-nvrtc-cu12==12.1.105
177
+ greenlet==3.1.1
178
+ alembic==1.13.3
wandb/run-20241028_085547-mga40p7t/files/wandb-metadata.json ADDED
@@ -0,0 +1,60 @@
1
+ {
2
+ "os": "Linux-5.15.161-ql-generic-13.0-14-x86_64-with-glibc2.35",
3
+ "python": "3.10.13",
4
+ "startedAt": "2024-10-28T04:55:47.033649Z",
5
+ "program": "/home/siddharth.tourani/Kyrylo/nlp/main.py",
6
+ "codePath": "main.py",
7
+ "email": "kyrylo.shyvam@students.iiit.ac.in",
8
+ "root": ".",
9
+ "host": "gpu-08",
10
+ "username": "siddharth.tourani",
11
+ "executable": "/home/siddharth.tourani/Minimal/bin/python3",
12
+ "codePathLocal": "main.py",
13
+ "cpu_count": 128,
14
+ "cpu_count_logical": 256,
15
+ "gpu": "[NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB]",
16
+ "gpu_count": 4,
17
+ "disk": {
18
+ "/": {
19
+ "total": "1073741824",
20
+ "used": "21049344"
21
+ }
22
+ },
23
+ "memory": {
24
+ "total": "270256893952"
25
+ },
26
+ "cpu": {
27
+ "count": 128,
28
+ "countLogical": 256
29
+ },
30
+ "gpu_nvidia": [
31
+ {
32
+ "name": "NVIDIA A100-SXM4-40GB",
33
+ "memoryTotal": "42949672960",
34
+ "cudaCores": 6912,
35
+ "architecture": "Ampere"
36
+ },
37
+ {
38
+ "name": "NVIDIA A100-SXM4-40GB",
39
+ "memoryTotal": "42949672960",
40
+ "cudaCores": 6912,
41
+ "architecture": "Ampere"
42
+ },
43
+ {
44
+ "name": "NVIDIA A100-SXM4-40GB",
45
+ "memoryTotal": "42949672960",
46
+ "cudaCores": 6912,
47
+ "architecture": "Ampere"
48
+ },
49
+ {
50
+ "name": "NVIDIA A100-SXM4-40GB",
51
+ "memoryTotal": "42949672960",
52
+ "cudaCores": 6912,
53
+ "architecture": "Ampere"
54
+ }
55
+ ],
56
+ "slurm": {
57
+ "job_id": "68153"
58
+ },
59
+ "cudaVersion": "12.4"
60
+ }
wandb/run-20241028_085547-mga40p7t/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
1
+ {"epoch":0,"Training loss_epoch":107.80526733398438,"_step":142,"_runtime":102.656591521,"Validation loss_epoch":107.9881820678711,"Training loss_step":108.02771759033203,"_timestamp":1.730091449689934e+09,"trainer/global_step":419,"Validation loss_step":108.47615051269531,"_wandb":{"runtime":105}}
wandb/run-20241028_085547-mga40p7t/logs/debug-core.log ADDED
@@ -0,0 +1,14 @@
1
+ {"time":"2024-10-28T08:55:46.356024757+04:00","level":"INFO","msg":"started logging, with flags","port-filename":"/tmp/tmptnuwuzf_/port-1239188.txt","pid":1239188,"debug":false,"disable-analytics":false}
2
+ {"time":"2024-10-28T08:55:46.356050947+04:00","level":"INFO","msg":"FeatureState","shutdownOnParentExitEnabled":false}
3
+ {"time":"2024-10-28T08:55:46.35809675+04:00","level":"INFO","msg":"Will exit if parent process dies.","ppid":1239188}
4
+ {"time":"2024-10-28T08:55:46.35809731+04:00","level":"INFO","msg":"server is running","addr":{"IP":"127.0.0.1","Port":36381,"Zone":""}}
5
+ {"time":"2024-10-28T08:55:46.482841457+04:00","level":"INFO","msg":"created new connection","id":"127.0.0.1:59774"}
6
+ {"time":"2024-10-28T08:55:47.034493884+04:00","level":"INFO","msg":"connection init received","streamId":"mga40p7t","id":"127.0.0.1:59774"}
7
+ {"time":"2024-10-28T08:55:47.035343787+04:00","level":"ERROR","msg":"error creating symlink","error":"symlink /home/siddharth.tourani/.cache/wandb/logs/core-debug-20241028_085546.log wandb/run-20241028_085547-mga40p7t/logs/debug-core.log: file exists"}
8
+ {"time":"2024-10-28T08:55:47.044292544+04:00","level":"INFO","msg":"connection init completed","streamId":"mga40p7t","id":"127.0.0.1:59774"}
9
+ {"time":"2024-10-28T08:57:32.990052624+04:00","level":"INFO","msg":"connection: teardown","id":"127.0.0.1:59774"}
10
+ {"time":"2024-10-28T08:57:32.990356191+04:00","level":"INFO","msg":"server is shutting down"}
11
+ {"time":"2024-10-28T08:57:32.990407101+04:00","level":"INFO","msg":"closed connection","id":"127.0.0.1:59774"}
12
+ {"time":"2024-10-28T08:57:33.264727032+04:00","level":"ERROR","msg":"error flushing writer","err":"write tcp 127.0.0.1:36381->127.0.0.1:59774: use of closed network connection","id":"127.0.0.1:59774"}
13
+ {"time":"2024-10-28T08:57:34.392709437+04:00","level":"INFO","msg":"connection closed","id":"127.0.0.1:59774"}
14
+ {"time":"2024-10-28T08:57:34.392722266+04:00","level":"INFO","msg":"server is closed"}
wandb/run-20241028_085547-mga40p7t/logs/debug-internal.log ADDED
@@ -0,0 +1,22 @@
1
+ {"time":"2024-10-28T08:55:47.035188748+04:00","level":"INFO","msg":"using version","core version":"0.18.1"}
2
+ {"time":"2024-10-28T08:55:47.035200088+04:00","level":"INFO","msg":"created symlink","path":"wandb/run-20241028_085547-mga40p7t/logs/debug-core.log"}
3
+ {"time":"2024-10-28T08:55:47.035571516+04:00","level":"INFO","msg":"using version","core version":"0.18.1"}
4
+ {"time":"2024-10-28T08:55:47.035578935+04:00","level":"INFO","msg":"created symlink","path":"wandb/run-20241028_085547-mga40p7t/logs/debug-core.log"}
5
+ {"time":"2024-10-28T08:55:47.044272764+04:00","level":"INFO","msg":"created new stream","id":"mga40p7t"}
6
+ {"time":"2024-10-28T08:55:47.044289244+04:00","level":"INFO","msg":"stream: started","id":"mga40p7t"}
7
+ {"time":"2024-10-28T08:55:47.044306184+04:00","level":"INFO","msg":"handler: started","stream_id":{"value":"mga40p7t"}}
8
+ {"time":"2024-10-28T08:55:47.044329513+04:00","level":"INFO","msg":"sender: started","stream_id":{"value":"mga40p7t"}}
9
+ {"time":"2024-10-28T08:55:47.044370663+04:00","level":"INFO","msg":"writer: Do: started","stream_id":{"value":"mga40p7t"}}
10
+ {"time":"2024-10-28T08:55:47.744002406+04:00","level":"INFO","msg":"wandb-core","!BADKEY":null}
11
+ {"time":"2024-10-28T08:55:47.747442758+04:00","level":"INFO","msg":"Starting system monitor"}
12
+ {"time":"2024-10-28T08:55:47.748927455+04:00","level":"ERROR","msg":"git repo not found","error":"repository does not exist"}
13
+ {"time":"2024-10-28T08:57:32.990331552+04:00","level":"INFO","msg":"stream: closing","id":"mga40p7t"}
14
+ {"time":"2024-10-28T08:57:32.990407231+04:00","level":"INFO","msg":"Stopping system monitor"}
15
+ {"time":"2024-10-28T08:57:32.992264416+04:00","level":"INFO","msg":"Stopped system monitor"}
16
+ {"time":"2024-10-28T08:57:33.519506024+04:00","level":"WARN","msg":"No job ingredients found, not creating job artifact"}
17
+ {"time":"2024-10-28T08:57:33.519518634+04:00","level":"WARN","msg":"No source type found, not creating job artifact"}
18
+ {"time":"2024-10-28T08:57:33.519521734+04:00","level":"INFO","msg":"sender: sendDefer: no job artifact to save"}
19
+ {"time":"2024-10-28T08:57:34.391853104+04:00","level":"INFO","msg":"handler: closed","stream_id":{"value":"mga40p7t"}}
20
+ {"time":"2024-10-28T08:57:34.391906283+04:00","level":"INFO","msg":"writer: Close: closed","stream_id":{"value":"mga40p7t"}}
21
+ {"time":"2024-10-28T08:57:34.391911083+04:00","level":"INFO","msg":"sender: closed","stream_id":{"value":"mga40p7t"}}
22
+ {"time":"2024-10-28T08:57:34.392637247+04:00","level":"INFO","msg":"stream: closed","id":"mga40p7t"}
wandb/run-20241028_085547-mga40p7t/logs/debug.log ADDED
@@ -0,0 +1,27 @@
1
+ 2024-10-28 08:55:47,029 INFO MainThread:1239188 [wandb_setup.py:_flush():77] Current SDK version is 0.18.1
2
+ 2024-10-28 08:55:47,029 INFO MainThread:1239188 [wandb_setup.py:_flush():77] Configure stats pid to 1239188
3
+ 2024-10-28 08:55:47,029 INFO MainThread:1239188 [wandb_setup.py:_flush():77] Loading settings from /home/siddharth.tourani/.config/wandb/settings
4
+ 2024-10-28 08:55:47,029 INFO MainThread:1239188 [wandb_setup.py:_flush():77] Loading settings from /home/siddharth.tourani/Kyrylo/nlp/wandb/settings
5
+ 2024-10-28 08:55:47,029 INFO MainThread:1239188 [wandb_setup.py:_flush():77] Loading settings from environment variables: {}
6
+ 2024-10-28 08:55:47,029 INFO MainThread:1239188 [wandb_setup.py:_flush():77] Applying setup settings: {'mode': None, '_disable_service': None}
7
+ 2024-10-28 08:55:47,029 INFO MainThread:1239188 [wandb_setup.py:_flush():77] Inferring run settings from compute environment: {'program_relpath': 'main.py', 'program_abspath': '/home/siddharth.tourani/Kyrylo/nlp/main.py', 'program': '/home/siddharth.tourani/Kyrylo/nlp/main.py'}
8
+ 2024-10-28 08:55:47,029 INFO MainThread:1239188 [wandb_setup.py:_flush():77] Applying login settings: {}
9
+ 2024-10-28 08:55:47,029 INFO MainThread:1239188 [wandb_init.py:_log_setup():532] Logging user logs to ./wandb/run-20241028_085547-mga40p7t/logs/debug.log
10
+ 2024-10-28 08:55:47,029 INFO MainThread:1239188 [wandb_init.py:_log_setup():533] Logging internal logs to ./wandb/run-20241028_085547-mga40p7t/logs/debug-internal.log
11
+ 2024-10-28 08:55:47,030 INFO MainThread:1239188 [wandb_init.py:init():616] calling init triggers
12
+ 2024-10-28 08:55:47,030 INFO MainThread:1239188 [wandb_init.py:init():623] wandb.init called with sweep_config: {}
13
+ config: {}
14
+ 2024-10-28 08:55:47,030 INFO MainThread:1239188 [wandb_init.py:init():666] starting backend
15
+ 2024-10-28 08:55:47,030 INFO MainThread:1239188 [wandb_init.py:init():670] setting up manager
16
+ 2024-10-28 08:55:47,030 INFO MainThread:1239188 [backend.py:_multiprocessing_setup():105] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
17
+ 2024-10-28 08:55:47,033 INFO MainThread:1239188 [wandb_init.py:init():678] backend started and connected
18
+ 2024-10-28 08:55:47,036 INFO MainThread:1239188 [wandb_init.py:init():773] updated telemetry
19
+ 2024-10-28 08:55:47,036 INFO MainThread:1239188 [wandb_init.py:init():806] communicating run to backend with 90.0 second timeout
20
+ 2024-10-28 08:55:47,740 INFO MainThread:1239188 [wandb_init.py:init():857] starting run threads in backend
21
+ 2024-10-28 08:55:47,953 INFO MainThread:1239188 [wandb_run.py:_console_start():2459] atexit reg
22
+ 2024-10-28 08:55:47,953 INFO MainThread:1239188 [wandb_run.py:_redirect():2307] redirect: wrap_raw
23
+ 2024-10-28 08:55:47,953 INFO MainThread:1239188 [wandb_run.py:_redirect():2372] Wrapping output streams.
24
+ 2024-10-28 08:55:47,953 INFO MainThread:1239188 [wandb_run.py:_redirect():2397] Redirects installed.
25
+ 2024-10-28 08:55:47,957 INFO MainThread:1239188 [wandb_init.py:init():900] run started, returning control to user process
26
+ 2024-10-28 08:55:48,061 INFO MainThread:1239188 [wandb_run.py:_config_callback():1388] config_cb None None {'lr': 0.0001}
27
+ 2024-10-28 08:57:32,990 WARNING MsgRouterThr:1239188 [router.py:message_loop():77] message_loop has been closed
wandb/run-20241028_085547-mga40p7t/run-mga40p7t.wandb ADDED
Binary file (500 kB).
 
wandb/run-20241028_085806-owcrwbil/files/code/main.py ADDED
@@ -0,0 +1,124 @@
1
+ from math import inf
2
+ # from utils import *
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import torch.utils
7
+ import torch.utils.data
8
+ # from utils import MyDataset, custom_collate
9
+ from torch.nn.utils.rnn import pad_sequence,pad_packed_sequence,pack_padded_sequence
10
+ import wandb
11
+ import torch.nn.functional as F
12
+
13
+ import einops
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast
15
+
16
+ np.random.seed(123)
17
+ torch.manual_seed(123)
18
+ torch.cuda.random.manual_seed(123)
19
+
20
+ import lightning as L
21
+ import utils
22
+
23
+
24
+ class PromptTuningModel(nn.Module):
25
+ def __init__(self, num_prompts=6):
26
+ super().__init__()
27
+ self.num_prompts = num_prompts
28
+
29
+ self.model = AutoModelForCausalLM.from_pretrained("gpt2", )
30
+ self.model.requires_grad_(False)
31
+ self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
32
+ self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
33
+
34
+ self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
35
+ self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
36
+
37
+ self.model.resize_token_embeddings(new_num_tokens = len(self.tokenizer), pad_to_multiple_of=128)
38
+
39
+ tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
40
+ token_embedding = self.model.transformer.wte(tmp[0])
41
+ self.token_embedding = token_embedding
42
+ for _ in range(num_prompts//3-1):
43
+ self.token_embedding = torch.cat([self.token_embedding, token_embedding])
44
+
45
+ # print(self.token_embedding.shape)
46
+ data = torch.zeros(num_prompts, 768) + self.token_embedding[:]
47
+ self.learnable_prompt = nn.Parameter(data, requires_grad=True)
48
+
49
+ # @torch.compile
50
+ def forward(self, X):
51
+ self.learnable_prompt = self.learnable_prompt.to(X.device)
52
+ embeddings = self.model.transformer.wte(X, ) # b s d
53
+
54
+ embeddings = torch.cat([embeddings, self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1)], dim=1)
55
+ mask = torch.cat([torch.ones([X.shape[0],self.num_prompts], dtype=torch.long).to(X), torch.where(X != self.pad.to(X.device), 1, 0)], dim=1)
56
+ # print(mask.shape)
57
+ logits = self.model(inputs_embeds = embeddings, attention_mask=mask).logits[:,self.num_prompts:].swapaxes(1,2)
58
+ return logits
59
+
60
+ class LitModelPromptTuning(L.LightningModule):
61
+ def __init__(self, model, lr=1e-4):
62
+ super().__init__()
63
+ self.model = model
64
+ self.lr = lr
65
+
66
+ self.save_hyperparameters(ignore=['model'])
67
+
68
+
69
+ def training_step(self, batch, batch_idx):
70
+ X, y = batch
71
+ # for i,j in zip(X[0], y[0]):
72
+ # print(i.item(),j.item())
73
+ # print(self.model.pad, self.model.eot)
74
+ logits = self.model(X)
75
+ loss = F.cross_entropy(logits, target=y, ignore_index=50257)
76
+ # print(loss)
77
+ # exit()
78
+
79
+ self.log('Training loss', loss, on_step=True, on_epoch=True,logger=True, sync_dist=True)
80
+ return loss
81
+
82
+
83
+ def validation_step(self, batch, batch_idx):
84
+ X, y = batch
85
+
86
+ logits = self.model(X)
87
+ loss = F.cross_entropy(logits, target=y, ignore_index=50257)
88
+
89
+ self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
90
+ return loss
91
+
92
+
93
+
94
+ def configure_optimizers(self):
95
+ optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
96
+ return optimizer
97
+
98
+ from lightning.pytorch.loggers import WandbLogger
99
+ if __name__ == '__main__':
100
+ torch.set_float32_matmul_precision('medium')
101
+ dl_train, dl_val, dl_test = utils.import_data(bs=20, fraction=0.1)
102
+ gpt_model = PromptTuningModel()
103
+ # gpt_model = torch.compile(gpt_model)
104
+ model = LitModelPromptTuning(model=gpt_model)
105
+ print('Training')
106
+
107
+ logger = WandbLogger(project='Anlp-3')
108
+ trainer = L.Trainer(
109
+ accelerator='gpu',
110
+ # strategy='auto',
111
+ # strategy=pl.strategies.DDPStrategy(find_unused_parameters=True),
112
+ devices=[3],
113
+ default_root_dir=f'./logs/', # Tensorflow can be used to viz
114
+ num_nodes=1,
115
+ num_sanity_val_steps=1, # runs a validation step before stating training
116
+ precision='16-mixed', # we use half precision to reduce memory usage
117
+ max_epochs=10,
118
+ check_val_every_n_epoch=1, # run validation every epoch
119
+ log_every_n_steps=20,
120
+ logger=logger
121
+ # detect_anomaly=True,
122
+ )
123
+
124
+ trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
wandb/run-20241028_085806-owcrwbil/files/config.yaml ADDED
@@ -0,0 +1,69 @@
1
+ _wandb:
2
+ value:
3
+ cli_version: 0.18.1
4
+ code_path: code/main.py
5
+ m:
6
+ - "1": epoch
7
+ "5": 2
8
+ "6":
9
+ - 1
10
+ - 3
11
+ "7": []
12
+ - "1": trainer/global_step
13
+ "6":
14
+ - 3
15
+ "7": []
16
+ - "1": Training loss_step
17
+ "5": 2
18
+ "6":
19
+ - 1
20
+ - 3
21
+ "7": []
22
+ - "1": Validation loss_step
23
+ "5": 2
24
+ "6":
25
+ - 1
26
+ - 3
27
+ "7": []
28
+ - "1": Validation loss_epoch
29
+ "5": 2
30
+ "6":
31
+ - 1
32
+ - 3
33
+ "7": []
34
+ - "1": Training loss_epoch
35
+ "5": 2
36
+ "6":
37
+ - 1
38
+ - 3
39
+ "7": []
40
+ python_version: 3.10.13
41
+ t:
42
+ "1":
43
+ - 1
44
+ - 11
45
+ - 49
46
+ - 55
47
+ - 71
48
+ - 106
49
+ "2":
50
+ - 1
51
+ - 11
52
+ - 49
53
+ - 55
54
+ - 71
55
+ - 106
56
+ "3":
57
+ - 7
58
+ - 23
59
+ - 55
60
+ - 66
61
+ "4": 3.10.13
62
+ "5": 0.18.1
63
+ "6": 4.44.2
64
+ "8":
65
+ - 5
66
+ "12": 0.18.1
67
+ "13": linux-x86_64
68
+ lr:
69
+ value: 0.0001
wandb/run-20241028_085806-owcrwbil/files/output.log ADDED
@@ -0,0 +1,16 @@
1
+ LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
2
+
3
+ | Name | Type | Params | Mode
4
+ ----------------------------------------------------
5
+ 0 | model | PromptTuningModel | 124 M | train
6
+ ----------------------------------------------------
7
+ 4.6 K Trainable params
8
+ 124 M Non-trainable params
9
+ 124 M Total params
10
+ 497.922 Total estimated model params size (MB)
11
+ 1 Modules in train mode
12
+ 164 Modules in eval mode
13
+ Epoch 1: 56%|███████████████████████████████████████████████████████▋ | 59/105 [00:24<00:19, 2.36it/s, v_num=wbil]
14
+
15
+
16
+ Detected KeyboardInterrupt, attempting graceful shutdown ...
wandb/run-20241028_085806-owcrwbil/files/requirements.txt ADDED
@@ -0,0 +1,178 @@
1
+ jiter==0.5.0
2
+ anyio==4.6.0
3
+ interegular==0.3.3
4
+ jaxlib==0.4.34
5
+ jsonschema==4.23.0
6
+ typing_extensions==4.12.2
7
+ httpcore==1.0.5
8
+ prometheus_client==0.21.0
9
+ openai==1.51.0
10
+ multidict==6.1.0
11
+ six==1.16.0
12
+ nvidia-nccl-cu12==2.20.5
13
+ nvidia-cuda-runtime-cu12==12.1.105
14
+ nvidia-cuda-cupti-cu12==12.1.105
15
+ nvidia-cudnn-cu12==9.1.0.70
16
+ watchfiles==0.24.0
17
+ tqdm==4.66.5
18
+ yarl==1.11.1
19
+ cffi==1.17.1
20
+ vllm==0.6.1.post2
21
+ bleach==6.1.0
22
+ kaggle==1.6.17
23
+ pydantic_core==2.23.4
24
+ lightning-utilities==0.11.7
25
+ sentry-sdk==2.14.0
26
+ torch==2.4.0
27
+ aiohappyeyeballs==2.4.0
28
+ diffusers==0.15.0
29
+ GitPython==3.1.43
30
+ attrs==24.2.0
31
+ importlib_metadata==8.5.0
32
+ transformers==4.44.2
33
+ pillow==10.4.0
34
+ sounddevice==0.5.1
35
+ gguf==0.9.1
36
+ python-dotenv==1.0.1
37
+ async-timeout==4.0.3
38
+ dspy-ai==2.5.3
39
+ numpy==1.26.4
40
+ nvidia-nvjitlink-cu12==12.6.68
41
+ uvicorn==0.30.6
42
+ kiwisolver==1.4.7
43
+ partial-json-parser==0.2.1.1.post4
44
+ pyparsing==3.1.4
45
+ lightning==2.4.0
46
+ structlog==24.4.0
47
+ nvidia-curand-cu12==10.3.2.106
48
+ setuptools==65.5.0
49
+ webencodings==0.5.1
50
+ nvidia-nvtx-cu12==12.1.105
51
+ sniffio==1.3.1
52
+ MarkupSafe==2.1.5
53
+ vllm-flash-attn==2.6.1
54
+ urllib3==2.2.3
55
+ requests==2.32.3
56
+ pycountry==24.6.1
57
+ ujson==5.10.0
58
+ matplotlib==3.9.2
59
+ pydantic==2.9.2
60
+ torchvision==0.19.0
61
+ numba==0.60.0
62
+ optuna==4.0.0
63
+ opt_einsum==3.4.0
64
+ joblib==1.4.2
65
+ msgpack==1.1.0
66
+ smmap==5.0.1
67
+ filelock==3.16.1
68
+ opencv-contrib-python==4.10.0.84
69
+ faiss-gpu==1.7.2
70
+ prometheus-fastapi-instrumentator==7.0.0
71
+ rpds-py==0.20.0
72
+ psutil==6.0.0
73
+ colorlog==6.8.2
74
+ nvidia-cufft-cu12==11.0.2.54
75
+ SQLAlchemy==2.0.35
76
+ llvmlite==0.43.0
77
+ packaging==24.1
78
+ exceptiongroup==1.2.2
79
+ dill==0.3.8
80
+ ml_dtypes==0.5.0
81
+ pyairports==2.1.1
82
+ scikit-learn==1.5.2
83
+ prettytable==3.11.0
84
+ protobuf==4.25.5
85
+ charset-normalizer==3.3.2
86
+ torchmetrics==1.4.2
87
+ text-unidecode==1.3
88
+ httpx==0.27.2
89
+ sympy==1.13.3
90
+ msgspec==0.18.6
91
+ wandb==0.18.1
92
+ backoff==2.2.1
93
+ sentencepiece==0.2.0
94
+ aiohttp==3.10.5
95
+ distro==1.9.0
96
+ lark==1.2.2
97
+ pyarrow==17.0.0
98
+ Mako==1.3.5
99
+ regex==2024.9.11
100
+ safetensors==0.4.5
101
+ aiosignal==1.3.1
102
+ jsonschema-specifications==2023.12.1
103
+ cloudpickle==3.0.0
104
+ einops==0.8.0
105
+ ray==2.36.1
106
+ fire==0.7.0
107
+ pyzmq==26.2.0
108
+ pycparser==2.22
109
+ platformdirs==4.3.6
110
+ click==8.1.7
111
+ fastapi==0.115.0
112
+ ftfy==6.3.0
113
+ torchtext==0.18.0
114
+ lm-format-enforcer==0.10.6
115
+ fsspec==2024.6.1
116
+ tzdata==2024.2
117
+ starlette==0.38.6
118
+ cycler==0.12.1
119
+ py-cpuinfo==9.0.0
120
+ h11==0.14.0
121
+ huggingface-hub==0.25.1
122
+ nvidia-cusparse-cu12==12.1.0.106
123
+ nvidia-ml-py==12.560.30
124
+ certifi==2024.8.30
125
+ httptools==0.6.1
126
+ jax==0.4.34
127
+ PyYAML==6.0.2
128
+ xxhash==3.5.0
129
+ idna==3.10
130
+ xformers==0.0.27.post2
131
+ mistral_common==1.4.3
132
+ fonttools==4.54.0
133
+ pip==23.0.1
134
+ accelerate==0.34.2
135
+ mediapipe==0.10.15
136
+ pytorch-lightning==2.4.0
137
+ ollama==0.3.3
138
+ Jinja2==3.1.4
139
+ multiprocess==0.70.16
140
+ opencv-python==4.10.0.84
141
+ termcolor==2.5.0
142
+ python-dateutil==2.9.0.post0
143
+ contourpy==1.3.0
144
+ websockets==13.1
145
+ frozenlist==1.4.1
146
+ pandas==2.2.3
147
+ networkx==3.3
148
+ diskcache==5.6.3
149
+ nvidia-cusolver-cu12==11.4.5.107
150
+ flatbuffers==24.3.25
151
+ mpmath==1.3.0
152
+ setproctitle==1.3.3
153
+ tokenizers==0.19.1
154
+ scipy==1.14.1
155
+ outlines==0.0.46
156
+ annotated-types==0.7.0
157
+ docker-pycreds==0.4.0
158
+ magicattr==0.1.6
159
+ wcwidth==0.2.13
160
+ pytorch-metric-learning==2.6.0
161
+ datasets==3.0.0
162
+ gitdb==4.0.11
163
+ lora-diffusion==0.1.7
164
+ referencing==0.35.1
165
+ python-slugify==8.0.4
166
+ zipp==3.20.2
167
+ triton==3.0.0
168
+ absl-py==2.1.0
169
+ threadpoolctl==3.5.0
170
+ uvloop==0.20.0
171
+ tiktoken==0.7.0
172
+ pytz==2024.2
173
+ nest-asyncio==1.6.0
174
+ nvidia-cublas-cu12==12.1.3.1
175
+ litellm==1.48.12
176
+ nvidia-cuda-nvrtc-cu12==12.1.105
177
+ greenlet==3.1.1
178
+ alembic==1.13.3
wandb/run-20241028_085806-owcrwbil/files/wandb-metadata.json ADDED
@@ -0,0 +1,60 @@
1
+ {
2
+ "os": "Linux-5.15.161-ql-generic-13.0-14-x86_64-with-glibc2.35",
3
+ "python": "3.10.13",
4
+ "startedAt": "2024-10-28T04:58:06.613389Z",
5
+ "program": "/home/siddharth.tourani/Kyrylo/nlp/main.py",
6
+ "codePath": "main.py",
7
+ "email": "kyrylo.shyvam@students.iiit.ac.in",
8
+ "root": ".",
9
+ "host": "gpu-08",
10
+ "username": "siddharth.tourani",
11
+ "executable": "/home/siddharth.tourani/Minimal/bin/python3",
12
+ "codePathLocal": "main.py",
13
+ "cpu_count": 128,
14
+ "cpu_count_logical": 256,
15
+ "gpu": "[NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB]",
16
+ "gpu_count": 4,
17
+ "disk": {
18
+ "/": {
19
+ "total": "1073741824",
20
+ "used": "21049344"
21
+ }
22
+ },
23
+ "memory": {
24
+ "total": "270256893952"
25
+ },
26
+ "cpu": {
27
+ "count": 128,
28
+ "countLogical": 256
29
+ },
30
+ "gpu_nvidia": [
31
+ {
32
+ "name": "NVIDIA A100-SXM4-40GB",
33
+ "memoryTotal": "42949672960",
34
+ "cudaCores": 6912,
35
+ "architecture": "Ampere"
36
+ },
37
+ {
38
+ "name": "NVIDIA A100-SXM4-40GB",
39
+ "memoryTotal": "42949672960",
40
+ "cudaCores": 6912,
41
+ "architecture": "Ampere"
42
+ },
43
+ {
44
+ "name": "NVIDIA A100-SXM4-40GB",
45
+ "memoryTotal": "42949672960",
46
+ "cudaCores": 6912,
47
+ "architecture": "Ampere"
48
+ },
49
+ {
50
+ "name": "NVIDIA A100-SXM4-40GB",
51
+ "memoryTotal": "42949672960",
52
+ "cudaCores": 6912,
53
+ "architecture": "Ampere"
54
+ }
55
+ ],
56
+ "slurm": {
57
+ "job_id": "68153"
58
+ },
59
+ "cudaVersion": "12.4"
60
+ }
wandb/run-20241028_085806-owcrwbil/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
1
+ {"_runtime":77.016630791,"Validation loss_step":107.9970703125,"Validation loss_epoch":108.69440460205078,"_timestamp":1.7300915636298473e+09,"_wandb":{"runtime":79},"epoch":1,"Training loss_step":107.52567291259766,"trainer/global_step":159,"_step":39,"Training loss_epoch":108.31751251220703}
wandb/run-20241028_085806-owcrwbil/logs/debug-core.log ADDED
@@ -0,0 +1,13 @@
1
+ {"time":"2024-10-28T08:58:05.852186282+04:00","level":"INFO","msg":"started logging, with flags","port-filename":"/tmp/tmpbt47w5hb/port-1241278.txt","pid":1241278,"debug":false,"disable-analytics":false}
2
+ {"time":"2024-10-28T08:58:05.852211372+04:00","level":"INFO","msg":"FeatureState","shutdownOnParentExitEnabled":false}
3
+ {"time":"2024-10-28T08:58:05.853103688+04:00","level":"INFO","msg":"Will exit if parent process dies.","ppid":1241278}
4
+ {"time":"2024-10-28T08:58:05.853096439+04:00","level":"INFO","msg":"server is running","addr":{"IP":"127.0.0.1","Port":39955,"Zone":""}}
5
+ {"time":"2024-10-28T08:58:06.04806319+04:00","level":"INFO","msg":"created new connection","id":"127.0.0.1:41868"}
6
+ {"time":"2024-10-28T08:58:06.614533728+04:00","level":"INFO","msg":"connection init received","streamId":"owcrwbil","id":"127.0.0.1:41868"}
7
+ {"time":"2024-10-28T08:58:06.615450814+04:00","level":"ERROR","msg":"error creating symlink","error":"symlink /home/siddharth.tourani/.cache/wandb/logs/core-debug-20241028_085805.log wandb/run-20241028_085806-owcrwbil/logs/debug-core.log: file exists"}
8
+ {"time":"2024-10-28T08:58:06.620019897+04:00","level":"INFO","msg":"connection init completed","streamId":"owcrwbil","id":"127.0.0.1:41868"}
9
+ {"time":"2024-10-28T08:59:25.917611765+04:00","level":"INFO","msg":"connection: teardown","id":"127.0.0.1:41868"}
10
+ {"time":"2024-10-28T08:59:25.917909894+04:00","level":"INFO","msg":"server is shutting down"}
11
+ {"time":"2024-10-28T08:59:25.917970554+04:00","level":"INFO","msg":"closed connection","id":"127.0.0.1:41868"}
12
+ {"time":"2024-10-28T08:59:32.101339343+04:00","level":"INFO","msg":"connection closed","id":"127.0.0.1:41868"}
13
+ {"time":"2024-10-28T08:59:32.101350683+04:00","level":"INFO","msg":"server is closed"}
wandb/run-20241028_085806-owcrwbil/logs/debug-internal.log ADDED
@@ -0,0 +1,22 @@
1
+ {"time":"2024-10-28T08:58:06.615297525+04:00","level":"INFO","msg":"using version","core version":"0.18.1"}
2
+ {"time":"2024-10-28T08:58:06.615309735+04:00","level":"INFO","msg":"created symlink","path":"wandb/run-20241028_085806-owcrwbil/logs/debug-core.log"}
3
+ {"time":"2024-10-28T08:58:06.615703263+04:00","level":"INFO","msg":"using version","core version":"0.18.1"}
4
+ {"time":"2024-10-28T08:58:06.615710223+04:00","level":"INFO","msg":"created symlink","path":"wandb/run-20241028_085806-owcrwbil/logs/debug-core.log"}
5
+ {"time":"2024-10-28T08:58:06.619998427+04:00","level":"INFO","msg":"created new stream","id":"owcrwbil"}
6
+ {"time":"2024-10-28T08:58:06.620015787+04:00","level":"INFO","msg":"stream: started","id":"owcrwbil"}
7
+ {"time":"2024-10-28T08:58:06.620034997+04:00","level":"INFO","msg":"sender: started","stream_id":{"value":"owcrwbil"}}
8
+ {"time":"2024-10-28T08:58:06.620043877+04:00","level":"INFO","msg":"handler: started","stream_id":{"value":"owcrwbil"}}
9
+ {"time":"2024-10-28T08:58:06.620038217+04:00","level":"INFO","msg":"writer: Do: started","stream_id":{"value":"owcrwbil"}}
10
+ {"time":"2024-10-28T08:58:07.277024244+04:00","level":"INFO","msg":"wandb-core","!BADKEY":null}
11
+ {"time":"2024-10-28T08:58:07.279017506+04:00","level":"INFO","msg":"Starting system monitor"}
12
+ {"time":"2024-10-28T08:58:07.28068419+04:00","level":"ERROR","msg":"git repo not found","error":"repository does not exist"}
13
+ {"time":"2024-10-28T08:59:25.917904554+04:00","level":"INFO","msg":"stream: closing","id":"owcrwbil"}
14
+ {"time":"2024-10-28T08:59:25.917952654+04:00","level":"INFO","msg":"Stopping system monitor"}
15
+ {"time":"2024-10-28T08:59:25.919306298+04:00","level":"INFO","msg":"Stopped system monitor"}
16
+ {"time":"2024-10-28T08:59:26.173567721+04:00","level":"WARN","msg":"No job ingredients found, not creating job artifact"}
17
+ {"time":"2024-10-28T08:59:26.173578511+04:00","level":"WARN","msg":"No source type found, not creating job artifact"}
18
+ {"time":"2024-10-28T08:59:26.173581711+04:00","level":"INFO","msg":"sender: sendDefer: no job artifact to save"}
19
+ {"time":"2024-10-28T08:59:32.100682276+04:00","level":"INFO","msg":"handler: closed","stream_id":{"value":"owcrwbil"}}
20
+ {"time":"2024-10-28T08:59:32.100731535+04:00","level":"INFO","msg":"writer: Close: closed","stream_id":{"value":"owcrwbil"}}
21
+ {"time":"2024-10-28T08:59:32.100765025+04:00","level":"INFO","msg":"sender: closed","stream_id":{"value":"owcrwbil"}}
22
+ {"time":"2024-10-28T08:59:32.101266313+04:00","level":"INFO","msg":"stream: closed","id":"owcrwbil"}
wandb/run-20241028_085806-owcrwbil/logs/debug.log ADDED
@@ -0,0 +1,27 @@
1
+ 2024-10-28 08:58:06,610 INFO MainThread:1241278 [wandb_setup.py:_flush():77] Current SDK version is 0.18.1
2
+ 2024-10-28 08:58:06,610 INFO MainThread:1241278 [wandb_setup.py:_flush():77] Configure stats pid to 1241278
3
+ 2024-10-28 08:58:06,610 INFO MainThread:1241278 [wandb_setup.py:_flush():77] Loading settings from /home/siddharth.tourani/.config/wandb/settings
4
+ 2024-10-28 08:58:06,610 INFO MainThread:1241278 [wandb_setup.py:_flush():77] Loading settings from /home/siddharth.tourani/Kyrylo/nlp/wandb/settings
5
+ 2024-10-28 08:58:06,610 INFO MainThread:1241278 [wandb_setup.py:_flush():77] Loading settings from environment variables: {}
6
+ 2024-10-28 08:58:06,610 INFO MainThread:1241278 [wandb_setup.py:_flush():77] Applying setup settings: {'mode': None, '_disable_service': None}
7
+ 2024-10-28 08:58:06,610 INFO MainThread:1241278 [wandb_setup.py:_flush():77] Inferring run settings from compute environment: {'program_relpath': 'main.py', 'program_abspath': '/home/siddharth.tourani/Kyrylo/nlp/main.py', 'program': '/home/siddharth.tourani/Kyrylo/nlp/main.py'}
8
+ 2024-10-28 08:58:06,610 INFO MainThread:1241278 [wandb_setup.py:_flush():77] Applying login settings: {}
9
+ 2024-10-28 08:58:06,610 INFO MainThread:1241278 [wandb_init.py:_log_setup():532] Logging user logs to ./wandb/run-20241028_085806-owcrwbil/logs/debug.log
10
+ 2024-10-28 08:58:06,610 INFO MainThread:1241278 [wandb_init.py:_log_setup():533] Logging internal logs to ./wandb/run-20241028_085806-owcrwbil/logs/debug-internal.log
11
+ 2024-10-28 08:58:06,611 INFO MainThread:1241278 [wandb_init.py:init():616] calling init triggers
12
+ 2024-10-28 08:58:06,611 INFO MainThread:1241278 [wandb_init.py:init():623] wandb.init called with sweep_config: {}
13
+ config: {}
14
+ 2024-10-28 08:58:06,611 INFO MainThread:1241278 [wandb_init.py:init():666] starting backend
15
+ 2024-10-28 08:58:06,611 INFO MainThread:1241278 [wandb_init.py:init():670] setting up manager
16
+ 2024-10-28 08:58:06,611 INFO MainThread:1241278 [backend.py:_multiprocessing_setup():105] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
17
+ 2024-10-28 08:58:06,613 INFO MainThread:1241278 [wandb_init.py:init():678] backend started and connected
18
+ 2024-10-28 08:58:06,615 INFO MainThread:1241278 [wandb_init.py:init():773] updated telemetry
19
+ 2024-10-28 08:58:06,616 INFO MainThread:1241278 [wandb_init.py:init():806] communicating run to backend with 90.0 second timeout
20
+ 2024-10-28 08:58:07,272 INFO MainThread:1241278 [wandb_init.py:init():857] starting run threads in backend
21
+ 2024-10-28 08:58:07,450 INFO MainThread:1241278 [wandb_run.py:_console_start():2459] atexit reg
22
+ 2024-10-28 08:58:07,450 INFO MainThread:1241278 [wandb_run.py:_redirect():2307] redirect: wrap_raw
23
+ 2024-10-28 08:58:07,450 INFO MainThread:1241278 [wandb_run.py:_redirect():2372] Wrapping output streams.
24
+ 2024-10-28 08:58:07,450 INFO MainThread:1241278 [wandb_run.py:_redirect():2397] Redirects installed.
25
+ 2024-10-28 08:58:07,452 INFO MainThread:1241278 [wandb_init.py:init():900] run started, returning control to user process
26
+ 2024-10-28 08:58:07,549 INFO MainThread:1241278 [wandb_run.py:_config_callback():1388] config_cb None None {'lr': 0.0001}
27
+ 2024-10-28 08:59:25,918 WARNING MsgRouterThr:1241278 [router.py:message_loop():77] message_loop has been closed
wandb/run-20241028_085806-owcrwbil/run-owcrwbil.wandb ADDED
Binary file (211 kB).
 
wandb/run-20241028_090044-f9fzz8iy/files/code/main.py ADDED
@@ -0,0 +1,124 @@
1
+ from math import inf
2
+ # from utils import *
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import torch.utils
7
+ import torch.utils.data
8
+ # from utils import MyDataset, custom_collate
9
+ from torch.nn.utils.rnn import pad_sequence,pad_packed_sequence,pack_padded_sequence
10
+ import wandb
11
+ import torch.nn.functional as F
12
+
13
+ import einops
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast
15
+
16
+ np.random.seed(123)
17
+ torch.manual_seed(123)
18
+ torch.cuda.random.manual_seed(123)
19
+
20
+ import lightning as L
21
+ import utils
22
+
23
+
24
+ class PromptTuningModel(nn.Module):
25
+ def __init__(self, num_prompts=6):
26
+ super().__init__()
27
+ self.num_prompts = num_prompts
28
+
29
+ self.model = AutoModelForCausalLM.from_pretrained("gpt2", )
30
+ self.model.requires_grad_(False)
31
+ self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
32
+ self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
33
+
34
+ self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
35
+ self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
36
+
37
+ self.model.resize_token_embeddings(new_num_tokens = len(self.tokenizer), pad_to_multiple_of=128)
38
+
39
+ tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
40
+ token_embedding = self.model.transformer.wte(tmp[0])
41
+ self.token_embedding = token_embedding
42
+ for _ in range(num_prompts//3-1):
43
+ self.token_embedding = torch.cat([self.token_embedding, token_embedding])
44
+
45
+ # print(self.token_embedding.shape)
46
+ data = torch.zeros(num_prompts, 768) + self.token_embedding[:]
47
+ self.learnable_prompt = nn.Parameter(data, requires_grad=True)
48
+
49
+ # @torch.compile
50
+ def forward(self, X):
51
+ self.learnable_prompt = self.learnable_prompt.to(X.device)
52
+ embeddings = self.model.transformer.wte(X, ) # b s d
53
+
54
+ embeddings = torch.cat([embeddings, self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1)], dim=1)
55
+ mask = torch.cat([torch.ones([X.shape[0],self.num_prompts], dtype=torch.long).to(X), torch.where(X != self.pad.to(X.device), 1, 0)], dim=1)
56
+ # print(mask.shape)
57
+ logits = self.model(inputs_embeds = embeddings, attention_mask=mask).logits[:,self.num_prompts:].swapaxes(1,2)
58
+ return logits
59
+
60
+ class LitModelPromptTuning(L.LightningModule):
61
+ def __init__(self, model, lr=1e-4):
62
+ super().__init__()
63
+ self.model = model
64
+ self.lr = lr
65
+
66
+ self.save_hyperparameters(ignore=['model'])
67
+
68
+
69
+ def training_step(self, batch, batch_idx):
70
+ X, y = batch
71
+ # for i,j in zip(X[0], y[0]):
72
+ # print(i.item(),j.item())
73
+ # print(self.model.pad, self.model.eot)
74
+ logits = self.model(X)
75
+ loss = F.cross_entropy(logits, target=y, ignore_index=50257)
76
+ # print(loss)
77
+ # exit()
78
+
79
+ self.log('Training loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
80
+ return loss
81
+
82
+
83
+ def validation_step(self, batch, batch_idx):
84
+ X, y = batch
85
+
86
+ logits = self.model(X)
87
+ loss = F.cross_entropy(logits, target=y, ignore_index=50257)
88
+
89
+ self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
90
+ return loss
91
+
92
+
93
+
94
+ def configure_optimizers(self):
95
+ optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
96
+ return optimizer
97
+
98
+ from lightning.pytorch.loggers import WandbLogger
99
+ if __name__ == '__main__':
100
+ torch.set_float32_matmul_precision('medium')
101
+ dl_train, dl_val, dl_test = utils.import_data(bs=25, fraction=0.1)
102
+ gpt_model = PromptTuningModel()
103
+ gpt_model = torch.compile(gpt_model)
104
+ model = LitModelPromptTuning(model=gpt_model)
105
+ print('Training')
106
+
107
+ logger = WandbLogger(project='Anlp-3')
108
+ trainer = L.Trainer(
109
+ accelerator='gpu',
110
+ # strategy='auto',
111
+ # strategy=pl.strategies.DDPStrategy(find_unused_parameters=True),
112
+ devices=[3],
113
+ default_root_dir=f'./logs/', # TensorBoard can be used to visualize the logs
114
+ num_nodes=1,
115
+ num_sanity_val_steps=1, # runs a validation step before starting training
116
+ precision='16-mixed', # we use half precision to reduce memory usage
117
+ max_epochs=10,
118
+ check_val_every_n_epoch=1, # run validation every epoch
119
+ log_every_n_steps=20,
120
+ logger=logger
121
+ # detect_anomaly=True,
122
+ )
123
+
124
+ trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
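
The "4.6 K Trainable params" line in the Lightning summaries above follows directly from this setup: GPT-2 is frozen with requires_grad_(False), so the only tensor left for AdamW to update is the learnable prompt, num_prompts * hidden_dim = 6 * 768 = 4,608 parameters, against roughly 124 M frozen GPT-2 weights. A small helper that splits parameters by their requires_grad flag, which is how the summary distinguishes trainable from non-trainable, reproduces the count; the demo module below is a stand-in with the same structure, not the real PromptTuningModel.

import torch
import torch.nn as nn

def count_params(module: nn.Module):
    # Split parameters by requires_grad, as Lightning's model summary does.
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in module.parameters() if not p.requires_grad)
    return trainable, frozen

# Stand-in: a frozen "backbone" plus a 6 x 768 trainable prompt.
demo = nn.Module()
demo.backbone = nn.Linear(768, 768).requires_grad_(False)
demo.prompt = nn.Parameter(torch.zeros(6, 768))
print(count_params(demo))  # (4608, 590592); the 4608 matches "4.6 K Trainable params"
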
wandb/run-20241028_090044-f9fzz8iy/files/config.yaml ADDED
@@ -0,0 +1,51 @@
1
+ _wandb:
2
+ value:
3
+ cli_version: 0.18.1
4
+ code_path: code/main.py
5
+ m:
6
+ - "1": trainer/global_step
7
+ "6":
8
+ - 3
9
+ "7": []
10
+ - "1": Training loss_step
11
+ "5": 1
12
+ "6":
13
+ - 1
14
+ - 3
15
+ "7": []
16
+ - "1": epoch
17
+ "5": 1
18
+ "6":
19
+ - 1
20
+ - 3
21
+ "7": []
22
+ python_version: 3.10.13
23
+ t:
24
+ "1":
25
+ - 1
26
+ - 11
27
+ - 49
28
+ - 55
29
+ - 71
30
+ - 106
31
+ "2":
32
+ - 1
33
+ - 11
34
+ - 49
35
+ - 55
36
+ - 71
37
+ - 106
38
+ "3":
39
+ - 7
40
+ - 23
41
+ - 55
42
+ - 66
43
+ "4": 3.10.13
44
+ "5": 0.18.1
45
+ "6": 4.44.2
46
+ "8":
47
+ - 5
48
+ "12": 0.18.1
49
+ "13": linux-x86_64
50
+ lr:
51
+ value: 0.0001
wandb/run-20241028_090044-f9fzz8iy/files/output.log ADDED
@@ -0,0 +1,15 @@
1
+ LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
2
+
3
+ | Name | Type | Params | Mode
4
+ ----------------------------------------------------
5
+ 0 | model | PromptTuningModel | 124 M | train
6
+ ----------------------------------------------------
7
+ 4.6 K Trainable params
8
+ 124 M Non-trainable params
9
+ 124 M Total params
10
+ 497.922 Total estimated model params size (MB)
11
+ 1 Modules in train mode
12
+ 164 Modules in eval mode
13
+ Epoch 0: 67%|██████████████████████████████████████████████████████████████████▋ | 56/84 [00:27<00:13, 2.07it/s, v_num=z8iy]
14
+
15
+ Detected KeyboardInterrupt, attempting graceful shutdown ...
wandb/run-20241028_090044-f9fzz8iy/files/requirements.txt ADDED
@@ -0,0 +1,178 @@
1
+ jiter==0.5.0
2
+ anyio==4.6.0
3
+ interegular==0.3.3
4
+ jaxlib==0.4.34
5
+ jsonschema==4.23.0
6
+ typing_extensions==4.12.2
7
+ httpcore==1.0.5
8
+ prometheus_client==0.21.0
9
+ openai==1.51.0
10
+ multidict==6.1.0
11
+ six==1.16.0
12
+ nvidia-nccl-cu12==2.20.5
13
+ nvidia-cuda-runtime-cu12==12.1.105
14
+ nvidia-cuda-cupti-cu12==12.1.105
15
+ nvidia-cudnn-cu12==9.1.0.70
16
+ watchfiles==0.24.0
17
+ tqdm==4.66.5
18
+ yarl==1.11.1
19
+ cffi==1.17.1
20
+ vllm==0.6.1.post2
21
+ bleach==6.1.0
22
+ kaggle==1.6.17
23
+ pydantic_core==2.23.4
24
+ lightning-utilities==0.11.7
25
+ sentry-sdk==2.14.0
26
+ torch==2.4.0
27
+ aiohappyeyeballs==2.4.0
28
+ diffusers==0.15.0
29
+ GitPython==3.1.43
30
+ attrs==24.2.0
31
+ importlib_metadata==8.5.0
32
+ transformers==4.44.2
33
+ pillow==10.4.0
34
+ sounddevice==0.5.1
35
+ gguf==0.9.1
36
+ python-dotenv==1.0.1
37
+ async-timeout==4.0.3
38
+ dspy-ai==2.5.3
39
+ numpy==1.26.4
40
+ nvidia-nvjitlink-cu12==12.6.68
41
+ uvicorn==0.30.6
42
+ kiwisolver==1.4.7
43
+ partial-json-parser==0.2.1.1.post4
44
+ pyparsing==3.1.4
45
+ lightning==2.4.0
46
+ structlog==24.4.0
47
+ nvidia-curand-cu12==10.3.2.106
48
+ setuptools==65.5.0
49
+ webencodings==0.5.1
50
+ nvidia-nvtx-cu12==12.1.105
51
+ sniffio==1.3.1
52
+ MarkupSafe==2.1.5
53
+ vllm-flash-attn==2.6.1
54
+ urllib3==2.2.3
55
+ requests==2.32.3
56
+ pycountry==24.6.1
57
+ ujson==5.10.0
58
+ matplotlib==3.9.2
59
+ pydantic==2.9.2
60
+ torchvision==0.19.0
61
+ numba==0.60.0
62
+ optuna==4.0.0
63
+ opt_einsum==3.4.0
64
+ joblib==1.4.2
65
+ msgpack==1.1.0
66
+ smmap==5.0.1
67
+ filelock==3.16.1
68
+ opencv-contrib-python==4.10.0.84
69
+ faiss-gpu==1.7.2
70
+ prometheus-fastapi-instrumentator==7.0.0
71
+ rpds-py==0.20.0
72
+ psutil==6.0.0
73
+ colorlog==6.8.2
74
+ nvidia-cufft-cu12==11.0.2.54
75
+ SQLAlchemy==2.0.35
76
+ llvmlite==0.43.0
77
+ packaging==24.1
78
+ exceptiongroup==1.2.2
79
+ dill==0.3.8
80
+ ml_dtypes==0.5.0
81
+ pyairports==2.1.1
82
+ scikit-learn==1.5.2
83
+ prettytable==3.11.0
84
+ protobuf==4.25.5
85
+ charset-normalizer==3.3.2
86
+ torchmetrics==1.4.2
87
+ text-unidecode==1.3
88
+ httpx==0.27.2
89
+ sympy==1.13.3
90
+ msgspec==0.18.6
91
+ wandb==0.18.1
92
+ backoff==2.2.1
93
+ sentencepiece==0.2.0
94
+ aiohttp==3.10.5
95
+ distro==1.9.0
96
+ lark==1.2.2
97
+ pyarrow==17.0.0
98
+ Mako==1.3.5
99
+ regex==2024.9.11
100
+ safetensors==0.4.5
101
+ aiosignal==1.3.1
102
+ jsonschema-specifications==2023.12.1
103
+ cloudpickle==3.0.0
104
+ einops==0.8.0
105
+ ray==2.36.1
106
+ fire==0.7.0
107
+ pyzmq==26.2.0
108
+ pycparser==2.22
109
+ platformdirs==4.3.6
110
+ click==8.1.7
111
+ fastapi==0.115.0
112
+ ftfy==6.3.0
113
+ torchtext==0.18.0
114
+ lm-format-enforcer==0.10.6
115
+ fsspec==2024.6.1
116
+ tzdata==2024.2
117
+ starlette==0.38.6
118
+ cycler==0.12.1
119
+ py-cpuinfo==9.0.0
120
+ h11==0.14.0
121
+ huggingface-hub==0.25.1
122
+ nvidia-cusparse-cu12==12.1.0.106
123
+ nvidia-ml-py==12.560.30
124
+ certifi==2024.8.30
125
+ httptools==0.6.1
126
+ jax==0.4.34
127
+ PyYAML==6.0.2
128
+ xxhash==3.5.0
129
+ idna==3.10
130
+ xformers==0.0.27.post2
131
+ mistral_common==1.4.3
132
+ fonttools==4.54.0
133
+ pip==23.0.1
134
+ accelerate==0.34.2
135
+ mediapipe==0.10.15
136
+ pytorch-lightning==2.4.0
137
+ ollama==0.3.3
138
+ Jinja2==3.1.4
139
+ multiprocess==0.70.16
140
+ opencv-python==4.10.0.84
141
+ termcolor==2.5.0
142
+ python-dateutil==2.9.0.post0
143
+ contourpy==1.3.0
144
+ websockets==13.1
145
+ frozenlist==1.4.1
146
+ pandas==2.2.3
147
+ networkx==3.3
148
+ diskcache==5.6.3
149
+ nvidia-cusolver-cu12==11.4.5.107
150
+ flatbuffers==24.3.25
151
+ mpmath==1.3.0
152
+ setproctitle==1.3.3
153
+ tokenizers==0.19.1
154
+ scipy==1.14.1
155
+ outlines==0.0.46
156
+ annotated-types==0.7.0
157
+ docker-pycreds==0.4.0
158
+ magicattr==0.1.6
159
+ wcwidth==0.2.13
160
+ pytorch-metric-learning==2.6.0
161
+ datasets==3.0.0
162
+ gitdb==4.0.11
163
+ lora-diffusion==0.1.7
164
+ referencing==0.35.1
165
+ python-slugify==8.0.4
166
+ zipp==3.20.2
167
+ triton==3.0.0
168
+ absl-py==2.1.0
169
+ threadpoolctl==3.5.0
170
+ uvloop==0.20.0
171
+ tiktoken==0.7.0
172
+ pytz==2024.2
173
+ nest-asyncio==1.6.0
174
+ nvidia-cublas-cu12==12.1.3.1
175
+ litellm==1.48.12
176
+ nvidia-cuda-nvrtc-cu12==12.1.105
177
+ greenlet==3.1.1
178
+ alembic==1.13.3
wandb/run-20241028_090044-f9fzz8iy/files/wandb-metadata.json ADDED
@@ -0,0 +1,60 @@
1
+ {
2
+ "os": "Linux-5.15.161-ql-generic-13.0-14-x86_64-with-glibc2.35",
3
+ "python": "3.10.13",
4
+ "startedAt": "2024-10-28T05:00:44.986886Z",
5
+ "program": "/home/siddharth.tourani/Kyrylo/nlp/main.py",
6
+ "codePath": "main.py",
7
+ "email": "kyrylo.shyvam@students.iiit.ac.in",
8
+ "root": ".",
9
+ "host": "gpu-08",
10
+ "username": "siddharth.tourani",
11
+ "executable": "/home/siddharth.tourani/Minimal/bin/python3",
12
+ "codePathLocal": "main.py",
13
+ "cpu_count": 128,
14
+ "cpu_count_logical": 256,
15
+ "gpu": "[NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB]",
16
+ "gpu_count": 4,
17
+ "disk": {
18
+ "/": {
19
+ "total": "1073741824",
20
+ "used": "21049344"
21
+ }
22
+ },
23
+ "memory": {
24
+ "total": "270256893952"
25
+ },
26
+ "cpu": {
27
+ "count": 128,
28
+ "countLogical": 256
29
+ },
30
+ "gpu_nvidia": [
31
+ {
32
+ "name": "NVIDIA A100-SXM4-40GB",
33
+ "memoryTotal": "42949672960",
34
+ "cudaCores": 6912,
35
+ "architecture": "Ampere"
36
+ },
37
+ {
38
+ "name": "NVIDIA A100-SXM4-40GB",
39
+ "memoryTotal": "42949672960",
40
+ "cudaCores": 6912,
41
+ "architecture": "Ampere"
42
+ },
43
+ {
44
+ "name": "NVIDIA A100-SXM4-40GB",
45
+ "memoryTotal": "42949672960",
46
+ "cudaCores": 6912,
47
+ "architecture": "Ampere"
48
+ },
49
+ {
50
+ "name": "NVIDIA A100-SXM4-40GB",
51
+ "memoryTotal": "42949672960",
52
+ "cudaCores": 6912,
53
+ "architecture": "Ampere"
54
+ }
55
+ ],
56
+ "slurm": {
57
+ "job_id": "68153"
58
+ },
59
+ "cudaVersion": "12.4"
60
+ }
wandb/run-20241028_090044-f9fzz8iy/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
1
+ {"_wandb":{"runtime":30},"_timestamp":1.7300916668905182e+09,"_runtime":21.903801157,"_step":1,"Training loss_step":104.03557586669922,"epoch":0,"trainer/global_step":39}
wandb/run-20241028_090044-f9fzz8iy/logs/debug-core.log ADDED
@@ -0,0 +1,13 @@
1
+ {"time":"2024-10-28T09:00:44.271373064+04:00","level":"INFO","msg":"started logging, with flags","port-filename":"/tmp/tmpr43jmo4e/port-1243414.txt","pid":1243414,"debug":false,"disable-analytics":false}
2
+ {"time":"2024-10-28T09:00:44.271400724+04:00","level":"INFO","msg":"FeatureState","shutdownOnParentExitEnabled":false}
3
+ {"time":"2024-10-28T09:00:44.272333799+04:00","level":"INFO","msg":"Will exit if parent process dies.","ppid":1243414}
4
+ {"time":"2024-10-28T09:00:44.27232289+04:00","level":"INFO","msg":"server is running","addr":{"IP":"127.0.0.1","Port":36087,"Zone":""}}
5
+ {"time":"2024-10-28T09:00:44.467138216+04:00","level":"INFO","msg":"created new connection","id":"127.0.0.1:52740"}
6
+ {"time":"2024-10-28T09:00:44.988394668+04:00","level":"INFO","msg":"connection init received","streamId":"f9fzz8iy","id":"127.0.0.1:52740"}
7
+ {"time":"2024-10-28T09:00:44.989450074+04:00","level":"ERROR","msg":"error creating symlink","error":"symlink /home/siddharth.tourani/.cache/wandb/logs/core-debug-20241028_090044.log wandb/run-20241028_090044-f9fzz8iy/logs/debug-core.log: file exists"}
8
+ {"time":"2024-10-28T09:00:44.994677699+04:00","level":"INFO","msg":"connection init completed","streamId":"f9fzz8iy","id":"127.0.0.1:52740"}
9
+ {"time":"2024-10-28T09:01:15.297081483+04:00","level":"INFO","msg":"connection: teardown","id":"127.0.0.1:52740"}
10
+ {"time":"2024-10-28T09:01:15.297356992+04:00","level":"INFO","msg":"server is shutting down"}
11
+ {"time":"2024-10-28T09:01:15.297394512+04:00","level":"INFO","msg":"closed connection","id":"127.0.0.1:52740"}
12
+ {"time":"2024-10-28T09:01:16.667813027+04:00","level":"INFO","msg":"connection closed","id":"127.0.0.1:52740"}
13
+ {"time":"2024-10-28T09:01:16.667830647+04:00","level":"INFO","msg":"server is closed"}
wandb/run-20241028_090044-f9fzz8iy/logs/debug-internal.log ADDED
@@ -0,0 +1,22 @@
1
+ {"time":"2024-10-28T09:00:44.989289054+04:00","level":"INFO","msg":"using version","core version":"0.18.1"}
2
+ {"time":"2024-10-28T09:00:44.989302624+04:00","level":"INFO","msg":"created symlink","path":"wandb/run-20241028_090044-f9fzz8iy/logs/debug-core.log"}
3
+ {"time":"2024-10-28T09:00:44.989755452+04:00","level":"INFO","msg":"using version","core version":"0.18.1"}
4
+ {"time":"2024-10-28T09:00:44.989762312+04:00","level":"INFO","msg":"created symlink","path":"wandb/run-20241028_090044-f9fzz8iy/logs/debug-core.log"}
5
+ {"time":"2024-10-28T09:00:44.994658889+04:00","level":"INFO","msg":"created new stream","id":"f9fzz8iy"}
6
+ {"time":"2024-10-28T09:00:44.994674479+04:00","level":"INFO","msg":"stream: started","id":"f9fzz8iy"}
7
+ {"time":"2024-10-28T09:00:44.994693389+04:00","level":"INFO","msg":"sender: started","stream_id":{"value":"f9fzz8iy"}}
8
+ {"time":"2024-10-28T09:00:44.994713418+04:00","level":"INFO","msg":"handler: started","stream_id":{"value":"f9fzz8iy"}}
9
+ {"time":"2024-10-28T09:00:44.994695909+04:00","level":"INFO","msg":"writer: Do: started","stream_id":{"value":"f9fzz8iy"}}
10
+ {"time":"2024-10-28T09:00:45.712012987+04:00","level":"INFO","msg":"wandb-core","!BADKEY":null}
11
+ {"time":"2024-10-28T09:00:45.713870838+04:00","level":"INFO","msg":"Starting system monitor"}
12
+ {"time":"2024-10-28T09:00:45.716597735+04:00","level":"ERROR","msg":"git repo not found","error":"repository does not exist"}
13
+ {"time":"2024-10-28T09:01:15.297352912+04:00","level":"INFO","msg":"stream: closing","id":"f9fzz8iy"}
14
+ {"time":"2024-10-28T09:01:15.297391541+04:00","level":"INFO","msg":"Stopping system monitor"}
15
+ {"time":"2024-10-28T09:01:15.2976609+04:00","level":"INFO","msg":"Stopped system monitor"}
16
+ {"time":"2024-10-28T09:01:15.570782955+04:00","level":"WARN","msg":"No job ingredients found, not creating job artifact"}
17
+ {"time":"2024-10-28T09:01:15.570791505+04:00","level":"WARN","msg":"No source type found, not creating job artifact"}
18
+ {"time":"2024-10-28T09:01:15.570794545+04:00","level":"INFO","msg":"sender: sendDefer: no job artifact to save"}
19
+ {"time":"2024-10-28T09:01:16.667000401+04:00","level":"INFO","msg":"handler: closed","stream_id":{"value":"f9fzz8iy"}}
20
+ {"time":"2024-10-28T09:01:16.667029211+04:00","level":"INFO","msg":"writer: Close: closed","stream_id":{"value":"f9fzz8iy"}}
21
+ {"time":"2024-10-28T09:01:16.667090251+04:00","level":"INFO","msg":"sender: closed","stream_id":{"value":"f9fzz8iy"}}
22
+ {"time":"2024-10-28T09:01:16.667475389+04:00","level":"INFO","msg":"stream: closed","id":"f9fzz8iy"}
wandb/run-20241028_090044-f9fzz8iy/logs/debug.log ADDED
@@ -0,0 +1,27 @@
1
+ 2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_setup.py:_flush():77] Current SDK version is 0.18.1
2
+ 2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_setup.py:_flush():77] Configure stats pid to 1243414
3
+ 2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_setup.py:_flush():77] Loading settings from /home/siddharth.tourani/.config/wandb/settings
4
+ 2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_setup.py:_flush():77] Loading settings from /home/siddharth.tourani/Kyrylo/nlp/wandb/settings
5
+ 2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_setup.py:_flush():77] Loading settings from environment variables: {}
6
+ 2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_setup.py:_flush():77] Applying setup settings: {'mode': None, '_disable_service': None}
7
+ 2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_setup.py:_flush():77] Inferring run settings from compute environment: {'program_relpath': 'main.py', 'program_abspath': '/home/siddharth.tourani/Kyrylo/nlp/main.py', 'program': '/home/siddharth.tourani/Kyrylo/nlp/main.py'}
8
+ 2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_setup.py:_flush():77] Applying login settings: {}
9
+ 2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_init.py:_log_setup():532] Logging user logs to ./wandb/run-20241028_090044-f9fzz8iy/logs/debug.log
10
+ 2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_init.py:_log_setup():533] Logging internal logs to ./wandb/run-20241028_090044-f9fzz8iy/logs/debug-internal.log
11
+ 2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_init.py:init():616] calling init triggers
12
+ 2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_init.py:init():623] wandb.init called with sweep_config: {}
13
+ config: {}
14
+ 2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_init.py:init():666] starting backend
15
+ 2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_init.py:init():670] setting up manager
16
+ 2024-10-28 09:00:44,985 INFO MainThread:1243414 [backend.py:_multiprocessing_setup():105] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
17
+ 2024-10-28 09:00:44,986 INFO MainThread:1243414 [wandb_init.py:init():678] backend started and connected
18
+ 2024-10-28 09:00:44,989 INFO MainThread:1243414 [wandb_init.py:init():773] updated telemetry
19
+ 2024-10-28 09:00:44,989 INFO MainThread:1243414 [wandb_init.py:init():806] communicating run to backend with 90.0 second timeout
20
+ 2024-10-28 09:00:45,707 INFO MainThread:1243414 [wandb_init.py:init():857] starting run threads in backend
21
+ 2024-10-28 09:00:45,884 INFO MainThread:1243414 [wandb_run.py:_console_start():2459] atexit reg
22
+ 2024-10-28 09:00:45,884 INFO MainThread:1243414 [wandb_run.py:_redirect():2307] redirect: wrap_raw
23
+ 2024-10-28 09:00:45,884 INFO MainThread:1243414 [wandb_run.py:_redirect():2372] Wrapping output streams.
24
+ 2024-10-28 09:00:45,884 INFO MainThread:1243414 [wandb_run.py:_redirect():2397] Redirects installed.
25
+ 2024-10-28 09:00:45,885 INFO MainThread:1243414 [wandb_init.py:init():900] run started, returning control to user process
26
+ 2024-10-28 09:00:45,983 INFO MainThread:1243414 [wandb_run.py:_config_callback():1388] config_cb None None {'lr': 0.0001}
27
+ 2024-10-28 09:01:15,297 WARNING MsgRouterThr:1243414 [router.py:message_loop():77] message_loop has been closed
wandb/run-20241028_090044-f9fzz8iy/run-f9fzz8iy.wandb ADDED
Binary file (60.2 kB).
 
wandb/run-20241028_090149-4jbvn26d/files/code/main.py ADDED
@@ -0,0 +1,124 @@
1
+ from math import inf
2
+ # from utils import *
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import torch.utils
7
+ import torch.utils.data
8
+ # from utils import MyDataset, custom_collate
9
+ from torch.nn.utils.rnn import pad_sequence,pad_packed_sequence,pack_padded_sequence
10
+ import wandb
11
+ import torch.nn.functional as F
12
+
13
+ import einops
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast
15
+
16
+ np.random.seed(123)
17
+ torch.manual_seed(123)
18
+ torch.cuda.random.manual_seed(123)
19
+
20
+ import lightning as L
21
+ import utils
22
+
23
+
24
+ class PromptTuningModel(nn.Module):
25
+ def __init__(self, num_prompts=6):
26
+ super().__init__()
27
+ self.num_prompts = num_prompts
28
+
29
+ self.model = AutoModelForCausalLM.from_pretrained("gpt2", )
30
+ self.model.requires_grad_(False)
31
+ self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
32
+ self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
33
+
34
+ self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
35
+ self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
36
+
37
+ self.model.resize_token_embeddings(new_num_tokens = len(self.tokenizer), pad_to_multiple_of=128)
38
+
39
+ tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
40
+ token_embedding = self.model.transformer.wte(tmp[0])
41
+ self.token_embedding = token_embedding
42
+ for _ in range(num_prompts//3-1):
43
+ self.token_embedding = torch.cat([self.token_embedding, token_embedding])
44
+
45
+ # print(self.token_embedding.shape)
46
+ data = torch.zeros(num_prompts, 768) + self.token_embedding[:]
47
+ self.learnable_prompt = nn.Parameter(data, requires_grad=True)
48
+
49
+ # @torch.compile
50
+ def forward(self, X):
51
+ self.learnable_prompt = self.learnable_prompt.to(X.device)
52
+ embeddings = self.model.transformer.wte(X, ) # b s d
53
+
54
+ embeddings = torch.cat([embeddings, self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1)], dim=1)
55
+ mask = torch.cat([torch.ones([X.shape[0],self.num_prompts], dtype=torch.long).to(X), torch.where(X != self.pad.to(X.device), 1, 0)], dim=1)
56
+ # print(mask.shape)
57
+ logits = self.model(inputs_embeds = embeddings, attention_mask=mask).logits[:,self.num_prompts:].swapaxes(1,2)
58
+ return logits
59
+
60
+ class LitModelPromptTuning(L.LightningModule):
61
+ def __init__(self, model, lr=1e-4):
62
+ super().__init__()
63
+ self.model = model
64
+ self.lr = lr
65
+
66
+ self.save_hyperparameters(ignore=['model'])
67
+
68
+
69
+ def training_step(self, batch, batch_idx):
70
+ X, y = batch
71
+ # for i,j in zip(X[0], y[0]):
72
+ # print(i.item(),j.item())
73
+ # print(self.model.pad, self.model.eot)
74
+ logits = self.model(X)
75
+ loss = F.cross_entropy(logits, target=y, ignore_index=50257)
76
+ # print(loss)
77
+ # exit()
78
+
79
+ self.log('Training loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
80
+ return loss
81
+
82
+
83
+ def validation_step(self, batch, batch_idx):
84
+ X, y = batch
85
+
86
+ logits = self.model(X)
87
+ loss = F.cross_entropy(logits, target=y, ignore_index=50257)
88
+
89
+ self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
90
+ return loss
91
+
92
+
93
+
94
+ def configure_optimizers(self):
95
+ optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
96
+ return optimizer
97
+
98
+ from lightning.pytorch.loggers import WandbLogger
99
+ if __name__ == '__main__':
100
+ torch.set_float32_matmul_precision('medium')
101
+ dl_train, dl_val, dl_test = utils.import_data(bs=25, fraction=0.1)
102
+ gpt_model = PromptTuningModel()
103
+ gpt_model = torch.compile(gpt_model)
104
+ model = LitModelPromptTuning(model=gpt_model)
105
+ print('Training')
106
+
107
+ logger = WandbLogger(project='Anlp-3')
108
+ trainer = L.Trainer(
109
+ accelerator='gpu',
110
+ # strategy='auto',
111
+ # strategy=pl.strategies.DDPStrategy(find_unused_parameters=True),
112
+ devices=[3],
113
+ default_root_dir=f'./logs/', # TensorBoard can be used to visualize the logs
114
+ num_nodes=1,
115
+ num_sanity_val_steps=1, # runs a validation step before starting training
116
+ precision='16-mixed', # we use half precision to reduce memory usage
117
+ max_epochs=10,
118
+ check_val_every_n_epoch=1, # run validation every epoch
119
+ log_every_n_steps=20,
120
+ logger=logger
121
+ # detect_anomaly=True,
122
+ )
123
+
124
+ trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
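
Both training_step and validation_step above rely on cross_entropy's ignore_index=50257, the id the added [PAD] token receives on top of GPT-2's original 50,257-token vocabulary, so padded target positions contribute neither loss nor gradients. Because forward returns logits transposed to (batch, vocab, seq) via swapaxes(1, 2), the class axis already sits where F.cross_entropy expects it for sequence targets. A toy-shape sketch of that call (all sizes illustrative):

import torch
import torch.nn.functional as F

B, S, V = 2, 4, 50258           # toy batch/seq; vocab includes the added [PAD]
PAD_ID = 50257                  # the ignore_index used in the script

logits = torch.randn(B, V, S)   # (batch, classes, seq), i.e. after swapaxes(1, 2)
y = torch.randint(0, V - 1, (B, S))
y[:, -1] = PAD_ID               # mark the last target position as padding

# Positions equal to ignore_index are excluded from the mean loss and gradients.
loss = F.cross_entropy(logits, target=y, ignore_index=PAD_ID)
print(loss)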