Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +42 -0
- __pycache__/main.cpython-310.pyc +0 -0
- __pycache__/main1.cpython-310.pyc +0 -0
- __pycache__/main2.cpython-310.pyc +0 -0
- __pycache__/utils.cpython-310.pyc +0 -0
- cnn_dailymail/test.csv +3 -0
- cnn_dailymail/train.csv +3 -0
- cnn_dailymail/validation.csv +3 -0
- last_layer.py +399 -0
- main.py +370 -0
- main1.py +425 -0
- main2.py +511 -0
- model0.bin +3 -0
- model1.bin +3 -0
- model2.bin +3 -0
- newspaper-text-summarization-cnn-dailymail.zip +3 -0
- utils.py +124 -0
- wandb/debug-internal.log +22 -0
- wandb/debug.log +27 -0
- wandb/run-20241028_085547-mga40p7t/files/code/main.py +124 -0
- wandb/run-20241028_085547-mga40p7t/files/config.yaml +69 -0
- wandb/run-20241028_085547-mga40p7t/files/output.log +16 -0
- wandb/run-20241028_085547-mga40p7t/files/requirements.txt +178 -0
- wandb/run-20241028_085547-mga40p7t/files/wandb-metadata.json +60 -0
- wandb/run-20241028_085547-mga40p7t/files/wandb-summary.json +1 -0
- wandb/run-20241028_085547-mga40p7t/logs/debug-core.log +14 -0
- wandb/run-20241028_085547-mga40p7t/logs/debug-internal.log +22 -0
- wandb/run-20241028_085547-mga40p7t/logs/debug.log +27 -0
- wandb/run-20241028_085547-mga40p7t/run-mga40p7t.wandb +0 -0
- wandb/run-20241028_085806-owcrwbil/files/code/main.py +124 -0
- wandb/run-20241028_085806-owcrwbil/files/config.yaml +69 -0
- wandb/run-20241028_085806-owcrwbil/files/output.log +16 -0
- wandb/run-20241028_085806-owcrwbil/files/requirements.txt +178 -0
- wandb/run-20241028_085806-owcrwbil/files/wandb-metadata.json +60 -0
- wandb/run-20241028_085806-owcrwbil/files/wandb-summary.json +1 -0
- wandb/run-20241028_085806-owcrwbil/logs/debug-core.log +13 -0
- wandb/run-20241028_085806-owcrwbil/logs/debug-internal.log +22 -0
- wandb/run-20241028_085806-owcrwbil/logs/debug.log +27 -0
- wandb/run-20241028_085806-owcrwbil/run-owcrwbil.wandb +0 -0
- wandb/run-20241028_090044-f9fzz8iy/files/code/main.py +124 -0
- wandb/run-20241028_090044-f9fzz8iy/files/config.yaml +51 -0
- wandb/run-20241028_090044-f9fzz8iy/files/output.log +15 -0
- wandb/run-20241028_090044-f9fzz8iy/files/requirements.txt +178 -0
- wandb/run-20241028_090044-f9fzz8iy/files/wandb-metadata.json +60 -0
- wandb/run-20241028_090044-f9fzz8iy/files/wandb-summary.json +1 -0
- wandb/run-20241028_090044-f9fzz8iy/logs/debug-core.log +13 -0
- wandb/run-20241028_090044-f9fzz8iy/logs/debug-internal.log +22 -0
- wandb/run-20241028_090044-f9fzz8iy/logs/debug.log +27 -0
- wandb/run-20241028_090044-f9fzz8iy/run-f9fzz8iy.wandb +0 -0
- wandb/run-20241028_090149-4jbvn26d/files/code/main.py +124 -0
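The commit title indicates this folder was pushed with huggingface_hub. As a minimal sketch of such a bulk upload (the repo id and token below are placeholders, not taken from this repository):

import os
from huggingface_hub import HfApi

api = HfApi(token=os.environ["HF_TOKEN"])  # placeholder: supply your own token
api.upload_folder(
    folder_path=".",                      # local project directory
    repo_id="your-username/your-repo",    # hypothetical repo id
    repo_type="model",
)

upload_folder commits every file in the folder in one go, which is why build artifacts such as __pycache__ and the wandb/ run logs below ended up in the repository as well.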
.gitattributes
CHANGED
@@ -33,3 +33,45 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+cnn_dailymail/test.csv filter=lfs diff=lfs merge=lfs -text
+cnn_dailymail/train.csv filter=lfs diff=lfs merge=lfs -text
+cnn_dailymail/validation.csv filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_090149-4jbvn26d/run-4jbvn26d.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_121715-5kjpxsew/run-5kjpxsew.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_121925-nc6un2i3/run-nc6un2i3.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_123711-98pjnrqo/run-98pjnrqo.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_130604-p34k5wnh/run-p34k5wnh.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_164041-fs40r39w/run-fs40r39w.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_165506-q960l24p/run-q960l24p.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_173226-2kgb0f9e/run-2kgb0f9e.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_183140-jbr9b02q/run-jbr9b02q.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_184435-8g9y4qy1/run-8g9y4qy1.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_191117-1eqkbgu2/run-1eqkbgu2.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_193646-r4qgqj1u/run-r4qgqj1u.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_195038-eakkycgw/run-eakkycgw.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_201220-6xkcnm4u/run-6xkcnm4u.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_202354-zr7kt8eh/run-zr7kt8eh.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_210645-j0itro1g/run-j0itro1g.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_210725-ho9p1f0s/run-ho9p1f0s.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_212229-ovcu6nj3/run-ovcu6nj3.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_212304-uq1xozlm/run-uq1xozlm.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_213435-qow8er7m/run-qow8er7m.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_213754-m1kaqlt1/run-m1kaqlt1.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_221407-dv4g6q0z/run-dv4g6q0z.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_221423-pxyf2xri/run-pxyf2xri.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_222813-k1xslrgl/run-k1xslrgl.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_223604-ucawfmok/run-ucawfmok.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_224822-nlehsykg/run-nlehsykg.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241028_225822-xuv6uhuc/run-xuv6uhuc.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241029_130621-0asef00f/run-0asef00f.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241029_134007-8a5jhu4s/run-8a5jhu4s.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241029_141057-uk43a4xl/run-uk43a4xl.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241029_151508-ya0e0d5g/run-ya0e0d5g.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241029_155952-fb5ojuk9/run-fb5ojuk9.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241029_182402-3dknsv44/run-3dknsv44.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241029_182613-mibkz7zt/run-mibkz7zt.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241029_183624-haor84lw/run-haor84lw.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241029_190201-8uotupup/run-8uotupup.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241029_190305-legb7y4v/run-legb7y4v.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241029_192824-grmmhjzz/run-grmmhjzz.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20241029_193507-ujpie5pz/run-ujpie5pz.wandb filter=lfs diff=lfs merge=lfs -text
__pycache__/main.cpython-310.pyc
ADDED
Binary file (8.91 kB)
__pycache__/main1.cpython-310.pyc
ADDED
Binary file (9.74 kB)
__pycache__/main2.cpython-310.pyc
ADDED
Binary file (10.8 kB)
__pycache__/utils.cpython-310.pyc
ADDED
Binary file (4.03 kB)
cnn_dailymail/test.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:69e091606539b415f768de75dd026bcee37b35b5b50b6088d0a6d0e017559d29
+size 49890690
cnn_dailymail/train.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd4ba100d4da4c5fe5414a590a5eca9cb47494e044c179a3aadb96cead676ab7
+size 1262015264
cnn_dailymail/validation.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:97dedb1d6fd51f94f74e1e9dfd3d7e175fc6b4f4417d9c1cdc33abd0b5fe54a1
+size 57691847
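The three cnn_dailymail entries above are Git LFS pointer files (version/oid/size), not the data itself; per the size fields, the real CSVs are roughly 1.26 GB (train), 58 MB (validation) and 50 MB (test) and are fetched on checkout. A minimal loading sketch, assuming git lfs pull has already replaced the pointers with the actual files:

import pandas as pd

# Assumes `git lfs pull` has materialised the real CSVs in cnn_dailymail/.
train = pd.read_csv("cnn_dailymail/train.csv")        # ~1.26 GB on disk
val = pd.read_csv("cnn_dailymail/validation.csv")     # ~58 MB
test = pd.read_csv("cnn_dailymail/test.csv")          # ~50 MB
print(train.shape, train.columns.tolist())            # column names depend on the export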
last_layer.py
ADDED
@@ -0,0 +1,399 @@
from math import inf
# from utils import *
import torch
import torch.nn as nn
import numpy as np
import torch.utils
import torch.utils.data
# from utils import MyDataset, custom_collate
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
import wandb
import torch.nn.functional as F

import einops
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast

np.random.seed(123)
torch.manual_seed(123)
torch.cuda.random.manual_seed(123)

import lightning as L
import utils
from torchmetrics.text.rouge import ROUGEScore


def top_p_sampling(logits, p=0.9, temperature=0.5):
    # Apply temperature scaling
    logits = logits / temperature

    # Sort logits and compute cumulative probabilities
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Mask tokens whose cumulative probability exceeds the threshold
    sorted_indices_to_remove = cumulative_probs > p
    # Shift the mask right so the first token above the threshold is kept
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Scatter the sorted mask back to the original vocabulary order
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits[indices_to_remove] = float('-inf')  # exclude masked logits from sampling

    # Sample from the remaining distribution
    probs = F.softmax(logits, dim=-1)
    sampled_indices = torch.multinomial(probs, num_samples=1)
    sampled_indices = sampled_indices.squeeze(1)

    return sampled_indices


class PromptTuningModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.requires_grad_(False)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]

        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))

        # Initialise the soft prompt by tiling the embedding of 'summarise'
        # (the word tokenizes to three ids, so num_prompts should be a multiple of 3)
        tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
        token_embedding = self.model.transformer.wte(tmp[0])
        self.token_embedding = token_embedding
        for _ in range(num_prompts // 3 - 1):
            self.token_embedding = torch.cat([self.token_embedding, token_embedding])

        data = torch.zeros(num_prompts, 768) + self.token_embedding[:]
        self.learnable_prompt = nn.Parameter(data, requires_grad=True)

    # @torch.compile
    def forward(self, X, y):
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)  # b s d

        # Prepend the learnable prompt to every sequence in the batch
        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)
        out = self.model(inputs_embeds=embeddings)
        logits = out.logits[:, self.num_prompts:]  # drop the soft-prompt positions
        return logits

    def generate_new(self, X):
        batch_size = X.shape[0]
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)
        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(batch_size, 1, 1), embeddings], dim=1)

        cnt = 0
        past_key_values = None
        generated_ids = torch.tensor([], dtype=torch.long, device=X.device).view(batch_size, 0)  # all generated tokens

        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            if cnt == 0:
                logits = out.logits[:, self.num_prompts:]
            else:
                logits = out.logits

            logits[:, :, 50257:] = -1e4  # never sample the added special tokens

            next_token_ids = top_p_sampling(logits[:, -1, :])  # shape (batch_size,)
            generated_ids = torch.cat([generated_ids, next_token_ids.unsqueeze(-1)], dim=-1)

            # Embed only the newly sampled token; shape (batch_size, 1, d)
            embeddings = self.model.transformer.wte(next_token_ids)[:, None, :]

            cnt += 1

            # Stop once every sequence contains the end-of-text token
            if torch.all((generated_ids == self.eot.item()).any(dim=-1)):
                break

        return generated_ids

    def generate(self, X):
        # Sampling loop with a KV cache; self.temperature is set by the Lightning wrapper
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)  # b s d
        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)

        cnt = 0
        past_key_values = None
        final_prediction = torch.empty(X.shape[0], 0, dtype=torch.long, device=X.device)
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            # On the first step the logits still include the soft-prompt positions
            logits = out.logits[:, self.num_prompts:] if cnt == 0 else out.logits
            logits[:, :, 50257:] = -1e4

            output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
            final_prediction = torch.cat([final_prediction, output], dim=1)
            embeddings = self.model.transformer.wte(output)

            cnt += 1
            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break

        return final_prediction


class LMModel(nn.Module):
    def __init__(self, num_prompts=0):
        super().__init__()
        self.num_prompts = num_prompts

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.requires_grad_(False)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]

        # Only the output head is trained ("last layer" fine-tuning)
        self.model.lm_head.requires_grad_(True)

    # @torch.compile
    def forward(self, X, y):
        embeddings = self.model.transformer.wte(X)  # b s d
        logits = self.model(inputs_embeds=embeddings).logits
        return logits

    def generate(self, X):
        # Same sampling loop as PromptTuningModel.generate, without the soft prompt
        embeddings = self.model.transformer.wte(X)  # b s d

        cnt = 0
        past_key_values = None
        final_prediction = torch.empty(X.shape[0], 0, dtype=torch.long, device=X.device)
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            logits = out.logits[:, self.num_prompts:] if cnt == 0 else out.logits
            logits[:, :, 50257:] = -1e4

            output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
            final_prediction = torch.cat([final_prediction, output], dim=1)
            embeddings = self.model.transformer.wte(output)

            cnt += 1
            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break

        return final_prediction


def zero_after_x(tensor, x):
    """
    Fills each row of a 2D tensor with x from the first occurrence of x
    onwards (used to blank everything after the end-of-text token).
    """
    mask = (tensor == x).cumsum(dim=1) > 0  # True from the first x onwards
    result = tensor.where(~mask, torch.ones_like(tensor, dtype=torch.long) * x)
    return result


class LitModelPromptTuning(L.LightningModule):
    def __init__(self, model, temperature, epoch=10, lr=1e-4):
        super().__init__()
        self.model = model
        self.lr = lr
        self.model.temperature = temperature  # read by the wrapped model's generate()
        self.epoch = epoch
        tokenize_to_strings = lambda text: self.model.tokenizer.convert_ids_to_tokens(self.model.tokenizer(text)["input_ids"])
        self.rouge = ROUGEScore(tokenizer=tokenize_to_strings)

        self.save_hyperparameters(ignore=['model'])

    def training_step(self, batch, batch_idx):
        X, y = batch
        logits = self.model(X, y)

        logits[:, :, 50257:] = -1e4  # suppress the added special tokens
        # y is assumed to be pre-aligned with X by the dataloader; 50257 is the [PAD] id
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]),
                               target=y[:, :-1].reshape(-1), ignore_index=50257)

        self.log('Training loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        logits = self.model(X, y)
        logits[:, :, 50257:] = -1e4
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]),
                               target=y[:, :-1].reshape(-1), ignore_index=50257)

        self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def on_test_epoch_start(self):
        self.all_text = []
        self.predicted_text = []

    def test_step(self, batch, batch_idx):
        if batch_idx == 0:  # the first test batch is skipped
            return
        X, y = batch
        out = self.model.generate(X)
        pred = self.model.tokenizer.batch_decode(out, skip_special_tokens=True)
        gt = self.model.tokenizer.batch_decode(y, skip_special_tokens=True)

        print(pred)
        print('GAP')
        print(gt)

        for p, g in zip(pred, gt):
            score = self.rouge(p, g)
            print(score)

        # Note: only the scores of the last (pred, gt) pair in the batch are logged
        self.log_dict(score, on_step=True, on_epoch=True, logger=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer


from lightning.pytorch.loggers import WandbLogger

if __name__ == '__main__':
    torch.set_float32_matmul_precision('medium')
    dl_train, dl_val, dl_test = utils.import_data(bs=25, fraction=0.1)
    # gpt_model = PromptTuningModel(num_prompts=12)
    gpt_model = LMModel(num_prompts=0)

    # gpt_model = torch.compile(gpt_model)
    model = LitModelPromptTuning(
        model=gpt_model,
        lr=1e-4,
        temperature=0.9,
        epoch=10,
    )
    print('Training')

    logger = WandbLogger(project='Anlp-3')
    trainer = L.Trainer(
        accelerator='gpu',
        devices=[2],
        default_root_dir='./logs/',  # TensorBoard can be used to visualize
        num_nodes=1,
        num_sanity_val_steps=1,      # runs a validation step before starting training
        precision='bf16-mixed',      # half precision to reduce memory usage
        max_epochs=5,
        check_val_every_n_epoch=1,   # run validation every epoch
        log_every_n_steps=20,
        logger=logger,
    )

    # trainer.test(model, dataloaders=dl_test)
    trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
    trainer.test(model, dataloaders=dl_test)
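A quick usage sketch for the top_p_sampling helper defined above (the logits here are random stand-ins for GPT-2 outputs, shown only to illustrate shapes):

import torch
# assumes top_p_sampling from last_layer.py is importable
torch.manual_seed(0)
fake_logits = torch.randn(4, 50257)              # (batch, vocab): one logit row per sequence
ids = top_p_sampling(fake_logits, p=0.9, temperature=0.5)
print(ids.shape)                                 # torch.Size([4]): one sampled token id per row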
main.py
ADDED
@@ -0,0 +1,370 @@
from math import inf
# from utils import *
import torch
import torch.nn as nn
import numpy as np
import torch.utils
import torch.utils.data
# from utils import MyDataset, custom_collate
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
import wandb
import torch.nn.functional as F

import einops
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast

np.random.seed(123)
torch.manual_seed(123)
torch.cuda.random.manual_seed(123)

import lightning as L
import utils
from torchmetrics.text.rouge import ROUGEScore


def top_p_sampling(logits, p=0.9, temperature=0.5):
    # Apply temperature scaling
    logits = logits / temperature

    # Sort logits and compute cumulative probabilities
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Mask tokens whose cumulative probability exceeds the threshold
    sorted_indices_to_remove = cumulative_probs > p
    # Shift the mask right so the first token above the threshold is kept
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Scatter the sorted mask back to the original vocabulary order
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits[indices_to_remove] = float('-inf')  # exclude masked logits from sampling

    # Sample from the remaining distribution
    probs = F.softmax(logits, dim=-1)
    sampled_indices = torch.multinomial(probs, num_samples=1)
    sampled_indices = sampled_indices.squeeze(1)

    return sampled_indices


class PromptTuningModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        # self.model.generation_config.cache_implementation = "static"
        # self.model.generation_config.max_new_tokens = 256
        self.model.requires_grad_(False)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
        self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]

        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))

        # Initialise the soft prompt by tiling the embedding of 'summarise'
        tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
        token_embedding = self.model.transformer.wte(tmp[0])
        self.token_embedding = token_embedding
        for _ in range(num_prompts // 3 - 1):
            self.token_embedding = torch.cat([self.token_embedding, token_embedding])

        data = torch.zeros(num_prompts, 768) + self.token_embedding[:]
        self.learnable_prompt = nn.Parameter(data, requires_grad=True)

        # self.model.transformer.wte.weight[self.start].requires_grad = True

    # @torch.compile
    def forward(self, X, y):
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)  # b s d

        # Prepend the learnable prompt to every sequence in the batch
        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)
        out = self.model(inputs_embeds=embeddings)
        logits = out.logits[:, self.num_prompts:]  # drop the soft-prompt positions
        return logits

    def generate_new(self, X):
        batch_size = X.shape[0]
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)
        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(batch_size, 1, 1), embeddings], dim=1)

        cnt = 0
        past_key_values = None
        generated_ids = torch.tensor([], dtype=torch.long, device=X.device).view(batch_size, 0)

        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            if cnt == 0:
                logits = out.logits[:, self.num_prompts:]
            else:
                logits = out.logits

            logits[:, :, 50257:] = -1e4  # never sample the added special tokens

            next_token_ids = top_p_sampling(logits[:, -1, :])  # shape (batch_size,)
            generated_ids = torch.cat([generated_ids, next_token_ids.unsqueeze(-1)], dim=-1)

            embeddings = self.model.transformer.wte(next_token_ids)[:, None, :]  # (batch, 1, d)

            cnt += 1

            # Stop once every sequence contains the end-of-text token
            if torch.all((generated_ids == self.eot.item()).any(dim=-1)):
                break

        return generated_ids

    def generate(self, X):
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)  # b s d
        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)

        cnt = 0
        past_key_values = None
        final_prediction = torch.empty(X.shape[0], 0, dtype=torch.long, device=X.device)
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            # On the first step the logits still include the soft-prompt positions
            logits = out.logits[:, self.num_prompts:] if cnt == 0 else out.logits
            logits[:, :, 50257:] = -1e4

            output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
            final_prediction = torch.cat([final_prediction, output], dim=1)
            embeddings = self.model.transformer.wte(output)

            cnt += 1
            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break

        return final_prediction


class LMModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.requires_grad_(False)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]

        # Only the output head is trained
        self.model.lm_head.requires_grad_(True)

    # @torch.compile
    def forward(self, X):
        embeddings = self.model.transformer.wte(X)  # b s d
        logits = self.model(inputs_embeds=embeddings).logits
        return logits


def zero_after_x(arr, x):
    """
    Fills each row of a 2D tensor with x from the first occurrence of x
    onwards (used to blank everything after the end-of-text token).
    """
    mask = (arr == x).cumsum(dim=1) > 0  # True from the first x onwards
    result = torch.where(mask, x, arr)
    return result


class LitModelPromptTuning(L.LightningModule):
    def __init__(self, model, temperature, epoch, lr=1e-4):
        super().__init__()
        self.model = model
        self.lr = lr
        self.model.temperature = temperature  # read by the wrapped model's generate()
        self.epoch = epoch
        self.temperature = temperature

        tokenize_to_strings = lambda text: self.model.tokenizer.convert_ids_to_tokens(self.model.tokenizer(text)["input_ids"])
        self.rouge = ROUGEScore(tokenizer=tokenize_to_strings)

        self.save_hyperparameters(ignore=['model'])

    def training_step(self, batch, batch_idx):
        X, y = batch
        logits = self.model(X, y)

        logits[:, :, 50257:] = -1e4  # suppress the added special tokens
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]),
                               target=y[:, :-1].reshape(-1), ignore_index=50257)

        self.log('Training loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        logits = self.model(X, y)
        logits[:, :, 50257:] = -1e4
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]),
                               target=y[:, :-1].reshape(-1), ignore_index=50257)

        self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def on_test_epoch_start(self):
        self.all_text = []
        self.predicted_text = []

    def test_step(self, batch, batch_idx):
        if batch_idx == 0:  # the first test batch is skipped
            return
        X, y = batch
        out = self.model.generate(X)
        pred = self.model.tokenizer.batch_decode(out, skip_special_tokens=False)
        gt = self.model.tokenizer.batch_decode(y, skip_special_tokens=False)

        print(pred)
        print('GAP')
        print(gt)

        for p, g in zip(pred, gt):
            score = self.rouge(p, g)
            print(score)

        self.log_dict(score, on_step=True, on_epoch=True, logger=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer


from lightning.pytorch.loggers import WandbLogger

if __name__ == '__main__':

    train = False

    torch.set_float32_matmul_precision('medium')
    dl_train, dl_val, dl_test = utils.import_data(bs=24, fraction=1)
    if train:
        gpt_model = PromptTuningModel(num_prompts=24)
        gpt_model = torch.compile(gpt_model)
    else:
        gpt_model = torch.load('./model0.bin')
    # gpt_model = LMModel(num_prompts=12)

    model = LitModelPromptTuning(
        model=gpt_model,
        lr=1e-3,
        temperature=0.9,
        epoch=5,
    )
    print('Training')

    logger = WandbLogger(project='Anlp-3')
    trainer = L.Trainer(
        accelerator='gpu',
        devices=1,
        default_root_dir='./logs/',  # TensorBoard can be used to visualize
        num_nodes=1,
        num_sanity_val_steps=1,      # runs a validation step before starting training
        precision='bf16-mixed',      # half precision to reduce memory usage
        max_epochs=5,
        check_val_every_n_epoch=1,   # run validation every epoch
        log_every_n_steps=20,
        logger=logger,
    )

    # trainer.test(model, dataloaders=dl_test)
    if train:
        trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
        trainer.test(model, dataloaders=dl_test)
        torch.save(model.model, './model0.bin')
    else:
        trainer.test(model, dataloaders=dl_test)
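main.py persists the whole wrapped module with torch.save(model.model, './model0.bin') and restores it with torch.load. One caveat, sketched below: a full-module pickle is restored by re-importing its class, so PromptTuningModel must be defined or importable in the loading process exactly as at save time, and recent PyTorch versions need weights_only=False for anything beyond bare tensors:

import torch

# Sketch: the class saved in model0.bin (here PromptTuningModel) must be
# importable under the same module path it had when torch.save was called.
gpt_model = torch.load("./model0.bin", map_location="cpu", weights_only=False)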
main1.py
ADDED
@@ -0,0 +1,425 @@
from math import inf
# from utils import *
import torch
import torch.nn as nn
import numpy as np
import torch.utils
import torch.utils.data
# from utils import MyDataset, custom_collate
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
import wandb
import torch.nn.functional as F

import einops
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast

np.random.seed(123)
torch.manual_seed(123)
torch.cuda.random.manual_seed(123)

import lightning as L
import utils
from torchmetrics.text.rouge import ROUGEScore


def top_p_sampling(logits, p=0.9, temperature=0.5):
    # Apply temperature scaling
    logits = logits / temperature

    # Sort logits and compute cumulative probabilities
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Mask tokens whose cumulative probability exceeds the threshold
    sorted_indices_to_remove = cumulative_probs > p
    # Shift the mask right so the first token above the threshold is kept
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Scatter the sorted mask back to the original vocabulary order
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits[indices_to_remove] = float('-inf')  # exclude masked logits from sampling

    # Sample from the remaining distribution
    probs = F.softmax(logits, dim=-1)
    sampled_indices = torch.multinomial(probs, num_samples=1)
    sampled_indices = sampled_indices.squeeze(1)

    return sampled_indices


class PromptTuningModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        # self.model.generation_config.cache_implementation = "static"
        # self.model.generation_config.max_new_tokens = 256
        self.model.requires_grad_(False)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
        self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]

        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))

        # Initialise the soft prompt by tiling the embedding of 'summarise'
        tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
        token_embedding = self.model.transformer.wte(tmp[0])
        self.token_embedding = token_embedding
        for _ in range(num_prompts // 3 - 1):
            self.token_embedding = torch.cat([self.token_embedding, token_embedding])

        data = torch.zeros(num_prompts, 768) + self.token_embedding[:]
        self.learnable_prompt = nn.Parameter(data, requires_grad=True)

        # self.model.transformer.wte.weight[self.start].requires_grad = True

    # @torch.compile
    def forward(self, X, y):
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)  # b s d

        # Prepend the learnable prompt to every sequence in the batch
        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)
        out = self.model(inputs_embeds=embeddings)
        logits = out.logits[:, self.num_prompts:]  # drop the soft-prompt positions
        return logits

    def generate_new(self, X):
        batch_size = X.shape[0]
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)
        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(batch_size, 1, 1), embeddings], dim=1)

        cnt = 0
        past_key_values = None
        generated_ids = torch.tensor([], dtype=torch.long, device=X.device).view(batch_size, 0)

        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            if cnt == 0:
                logits = out.logits[:, self.num_prompts:]
            else:
                logits = out.logits

            logits[:, :, 50257:] = -1e4  # never sample the added special tokens

            next_token_ids = top_p_sampling(logits[:, -1, :])  # shape (batch_size,)
            generated_ids = torch.cat([generated_ids, next_token_ids.unsqueeze(-1)], dim=-1)

            embeddings = self.model.transformer.wte(next_token_ids)[:, None, :]  # (batch, 1, d)

            cnt += 1

            # Stop once every sequence contains the end-of-text token
            if torch.all((generated_ids == self.eot.item()).any(dim=-1)):
                break

        return generated_ids

    def generate(self, X):
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)  # b s d
        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)

        cnt = 0
        past_key_values = None
        final_prediction = torch.empty(X.shape[0], 0, dtype=torch.long, device=X.device)
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            # On the first step the logits still include the soft-prompt positions
            logits = out.logits[:, self.num_prompts:] if cnt == 0 else out.logits
            logits[:, :, 50257:] = -1e4

            output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
            final_prediction = torch.cat([final_prediction, output], dim=1)
            embeddings = self.model.transformer.wte(output)

            cnt += 1
            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break

        return final_prediction


class LMModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        # self.model.generation_config.cache_implementation = "static"
        # self.model.generation_config.max_new_tokens = 256
        self.model.requires_grad_(False)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
        self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]

        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))

        # Only the output head is trained
        self.model.lm_head.requires_grad_(True)

    # @torch.compile
    def forward(self, X, y):
        embeddings = self.model.transformer.wte(X)  # b s d
        logits = self.model(inputs_embeds=embeddings).logits
        return logits

    def generate(self, X):
        # Same sampling loop as PromptTuningModel.generate, without the soft prompt
        embeddings = self.model.transformer.wte(X)  # b s d

        cnt = 0
        past_key_values = None
        final_prediction = torch.empty(X.shape[0], 0, dtype=torch.long, device=X.device)
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            logits = out.logits[:, self.num_prompts:] if cnt == 0 else out.logits
            logits[:, :, 50257:] = -1e4

            output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
            final_prediction = torch.cat([final_prediction, output], dim=1)
            embeddings = self.model.transformer.wte(output)

            cnt += 1
            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break

        return final_prediction


def zero_after_x(arr, x):
    """
    Fills each row of a 2D tensor with x from the first occurrence of x
    onwards (used to blank everything after the end-of-text token).
    """
    mask = (arr == x).cumsum(dim=1) > 0  # True from the first x onwards
    result = torch.where(mask, x, arr)
    return result


class LitModelPromptTuning(L.LightningModule):
    def __init__(self, model, temperature, epoch, lr=1e-4, **kwargs):
        super().__init__()
        self.model = model
        self.lr = lr
        self.model.temperature = temperature  # read by the wrapped model's generate()
        self.epoch = epoch
        self.temperature = temperature

        # Any extra keyword arguments become attributes
        for key, value in kwargs.items():
            setattr(self, key, value)

        tokenize_to_strings = lambda text: self.model.tokenizer.convert_ids_to_tokens(self.model.tokenizer(text)["input_ids"])
        self.rouge = ROUGEScore(tokenizer=tokenize_to_strings)

        self.save_hyperparameters(ignore=['model'])

    def training_step(self, batch, batch_idx):
        X, y = batch
        logits = self.model(X, y)

        logits[:, :, 50257:] = -1e4  # suppress the added special tokens
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]),
                               target=y[:, :-1].reshape(-1), ignore_index=50257)

        self.log('Training loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        logits = self.model(X, y)
        logits[:, :, 50257:] = -1e4
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]),
                               target=y[:, :-1].reshape(-1), ignore_index=50257)

        self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def on_test_epoch_start(self):
        self.all_text = []
        self.predicted_text = []

    def test_step(self, batch, batch_idx):
        if batch_idx == 0:  # the first test batch is skipped
            return
        X, y = batch
        out = self.model.generate(X)
        pred = self.model.tokenizer.batch_decode(out, skip_special_tokens=False)
        gt = self.model.tokenizer.batch_decode(y, skip_special_tokens=False)

        print(pred)
        print('GAP')
        print(gt)

        for p, g in zip(pred, gt):
            score = self.rouge(p, g)
            print(score)

        self.log_dict(score, on_step=True, on_epoch=True, logger=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer


from lightning.pytorch.loggers import WandbLogger

if __name__ == '__main__':
    train = False

    torch.set_float32_matmul_precision('medium')
    dl_train, dl_val, dl_test = utils.import_data(bs=24, fraction=1)
|
384 |
+
# gpt_model = PromptTuningModel(num_prompts=24)
|
385 |
+
if train:
|
386 |
+
gpt_model = LMModel(num_prompts=12)
|
387 |
+
gpt_model = torch.compile(gpt_model)
|
388 |
+
else:
|
389 |
+
gpt_model = torch.load('./model1.bin')
|
390 |
+
|
391 |
+
|
392 |
+
model = LitModelPromptTuning(
|
393 |
+
model=gpt_model,
|
394 |
+
lr=1e-4,
|
395 |
+
temperature=0.9,
|
396 |
+
epoch = 5,
|
397 |
+
|
398 |
+
type_model = 'lm_head'
|
399 |
+
)
|
400 |
+
print('Training')
|
401 |
+
|
402 |
+
logger = WandbLogger(project='Anlp-3')
|
403 |
+
trainer = L.Trainer(
|
404 |
+
accelerator='gpu',
|
405 |
+
# limit_train_batches=1,
|
406 |
+
# strategy='auto',
|
407 |
+
# strategy=pl.strategies.DDPStrategy(find_unused_parameters=True),
|
408 |
+
devices=1,
|
409 |
+
default_root_dir=f'./logs/', # Tensorflow can be used to viz
|
410 |
+
num_nodes=1,
|
411 |
+
num_sanity_val_steps=1, # runs a validation step before stating training
|
412 |
+
precision='bf16-mixed', # we use half precision to reduce memory usage
|
413 |
+
max_epochs=5,
|
414 |
+
check_val_every_n_epoch=1, # run validation every epoch
|
415 |
+
log_every_n_steps=20,
|
416 |
+
logger=logger,
|
417 |
+
# detect_anomaly=True,
|
418 |
+
)
|
419 |
+
|
420 |
+
if train:
|
421 |
+
trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
|
422 |
+
trainer.test(model, dataloaders=dl_test)
|
423 |
+
torch.save(model.model, './model1.bin')
|
424 |
+
else:
|
425 |
+
trainer.test(model, dataloaders=dl_test)
|
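
A note on zero_after_x above: despite its name, it overwrites (rather than zeroes) every element past the first occurrence of x, which is what the test pipeline needs to discard tokens sampled after <|endoftext|>. A minimal self-contained sketch, assuming x is GPT-2's end-of-text id 50256 and using made-up token ids:

import torch

def zero_after_x(arr, x):
    mask = (arr == x).cumsum(dim=1) > 0  # True from the first x onwards
    return torch.where(mask, x, arr)

preds = torch.tensor([[5, 9, 50256, 7, 8],
                      [3, 4, 6, 50256, 2]])
print(zero_after_x(preds, 50256))
# tensor([[    5,     9, 50256, 50256, 50256],
#         [    3,     4,     6, 50256, 50256]])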
main2.py
ADDED
@@ -0,0 +1,511 @@
from math import inf
# from utils import *
import torch
import torch.nn as nn
import numpy as np
import torch.utils
import torch.utils.data
# from utils import MyDataset, custom_collate
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
import wandb
import torch.nn.functional as F

import einops
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast

np.random.seed(123)
torch.manual_seed(123)
torch.cuda.random.manual_seed(123)

import lightning as L
import utils
from torchmetrics.text.rouge import ROUGEScore


def top_p_sampling(logits, p=0.9, temperature=0.5):
    # Apply temperature scaling
    logits = logits / temperature

    # Sort logits and get cumulative probabilities
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Create a mask for tokens outside the nucleus
    sorted_indices_to_remove = cumulative_probs > p
    # Shift the mask right so the first token that crosses p is also kept
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Scatter the sorted mask back to the original token order
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits[indices_to_remove] = float('-inf')  # Set unwanted logits to -inf

    # Sample from the remaining logits
    probs = F.softmax(logits, dim=-1)
    sampled_indices = torch.multinomial(probs, num_samples=1)
    sampled_indices = sampled_indices.squeeze(1)

    return sampled_indices


class PromptTuningModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        # self.model.generation_config.cache_implementation = "static"
        # self.model.generation_config.max_new_tokens = 256
        self.model.requires_grad_(False)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
        self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]

        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))

        # Initialise the learnable prompt from repeated embeddings of 'summarise'
        tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
        token_embedding = self.model.transformer.wte(tmp[0])
        self.token_embedding = token_embedding
        for _ in range(num_prompts // 3 - 1):
            self.token_embedding = torch.cat([self.token_embedding, token_embedding])

        data = torch.zeros(num_prompts, 768) + self.token_embedding[:]
        self.learnable_prompt = nn.Parameter(data, requires_grad=True)

        # self.model.transformer.wte.weight[self.start].requires_grad = True

    # @torch.compile
    def forward(self, X, y):
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)  # b s d

        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)
        # mask = torch.cat([torch.ones([X.shape[0], self.num_prompts], dtype=torch.long).to(X), torch.where(X != self.pad.to(X.device), 1, 0)], dim=1)
        # labels = torch.where(y == 50257, -100, y)
        # ignore = torch.ones([X.shape[0], self.num_prompts], dtype=torch.long, device=X.device) * -100
        # labels = torch.cat([ignore, labels], dim=1)
        out = self.model(inputs_embeds=embeddings)
        logits = out.logits[:, self.num_prompts:]
        return logits

    def generate_new(self, X):
        batch_size = X.shape[0]
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)
        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(batch_size, 1, 1), embeddings], dim=1)

        cnt = 0
        past_key_values = None
        generated_ids = torch.tensor([], dtype=torch.long, device=X.device).view(batch_size, 0)  # Store all generated tokens

        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            if cnt == 0:
                logits = out.logits[:, self.num_prompts:]
            else:
                logits = out.logits

            logits[:, :, 50257:] = -1e4  # Apply after slicing for correct dimensions

            next_token_ids = top_p_sampling(logits[:, -1, :])  # shape (batch_size,)
            generated_ids = torch.cat([generated_ids, next_token_ids.unsqueeze(-1)], dim=-1)

            embeddings = self.model.transformer.wte(next_token_ids)  # embeddings for the freshly sampled tokens

            cnt += 1

            # Stop once every sequence has produced the <|endoftext|> token
            if torch.all((generated_ids == self.eot.item()).any(dim=-1)):
                break

        return generated_ids

    def generate(self, X):
        # Only bs = 1
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)  # b s d
        embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)

        cnt = 0
        past_key_values = None
        final_prediction = torch.tensor([], dtype=torch.long).to(X.device)
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            if cnt == 0:
                logits = out.logits[:, self.num_prompts:]
            else:
                logits = out.logits
            logits[:, :, 50257:] = -1e4

            output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
            final_prediction = torch.cat([final_prediction, output], dim=1)
            embeddings = self.model.transformer.wte(output)

            cnt += 1
            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break

        return final_prediction


from peft import PeftModel, LoraConfig, get_peft_model


class LoraModel(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.num_prompts = 0
        self.dim = dim

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        # self.model.generation_config.cache_implementation = "static"
        # self.model.generation_config.max_new_tokens = 256
        self.model.requires_grad_(False)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
        self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]

        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))

        lora_config = LoraConfig(
            r=dim,                      # Rank of the low-rank matrices
            lora_alpha=32,              # Scaling factor
            target_modules=["c_attn"],  # GPT-2's fused attention projection
            lora_dropout=0.05,          # Dropout probability for LoRA layers
            bias="none",                # No bias parameters in the adapters
            task_type="CAUSAL_LM",
        )
        self.model = get_peft_model(self.model, lora_config)

    # @torch.compile
    def forward(self, X, y):
        embeddings = self.model.transformer.wte(X)  # b s d
        logits = self.model(inputs_embeds=embeddings).logits
        return logits

    def generate(self, X):
        # Only bs = 1
        embeddings = self.model.transformer.wte(X)  # b s d

        cnt = 0
        past_key_values = None
        final_prediction = torch.tensor([], dtype=torch.long).to(X.device)
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            if cnt == 0:
                logits = out.logits[:, self.num_prompts:]
            else:
                logits = out.logits
            logits[:, :, 50257:] = -1e4

            output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
            final_prediction = torch.cat([final_prediction, output], dim=1)
            embeddings = self.model.transformer.wte(output)

            cnt += 1
            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break

        return final_prediction


class LMModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.requires_grad_(False)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
        self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]

        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))

        # Re-enable gradients only for the language-modelling head
        self.model.lm_head.requires_grad_(True)

    # @torch.compile
    def forward(self, X, y):
        embeddings = self.model.transformer.wte(X)  # b s d
        logits = self.model(inputs_embeds=embeddings).logits
        return logits

    def generate(self, X):
        # Only bs = 1
        embeddings = self.model.transformer.wte(X)  # b s d

        cnt = 0
        past_key_values = None
        final_prediction = torch.tensor([], dtype=torch.long).to(X.device)
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            if cnt == 0:
                logits = out.logits[:, self.num_prompts:]
            else:
                logits = out.logits
            logits[:, :, 50257:] = -1e4

            output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
            final_prediction = torch.cat([final_prediction, output], dim=1)
            embeddings = self.model.transformer.wte(output)

            cnt += 1
            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break

        return final_prediction


def zero_after_x(arr, x):
    """
    Overwrites all elements in each row of a 2D tensor after the first
    occurrence of x with x itself, masking everything generated past the
    first end-of-text token.

    Args:
        arr: The input 2D tensor.
        x: The value after whose first occurrence elements are overwritten.

    Returns:
        A new tensor with elements after the first x replaced by x.
    """
    mask = (arr == x).cumsum(dim=1) > 0  # True from the first occurrence of x onwards
    result = torch.where(mask, x, arr)   # overwrite where the mask is True
    return result


class LitModelPromptTuning(L.LightningModule):
    def __init__(self, model, temperature, epoch, lr=1e-4, **kwargs):
        super().__init__()
        self.model = model
        self.lr = lr
        self.model.temperature = temperature
        self.epoch = epoch
        self.temperature = temperature

        for key, value in kwargs.items():
            setattr(self, key, value)

        tokenize_to_strings = lambda text: self.model.tokenizer.convert_ids_to_tokens(self.model.tokenizer(text)["input_ids"])
        self.rouge = ROUGEScore(tokenizer=tokenize_to_strings)

        self.save_hyperparameters(ignore=['model'])

    def training_step(self, batch, batch_idx):
        X, y = batch
        logits = self.model(X, y)
        logits[:, :, 50257:] = -1e4
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]), target=y[:, :-1].reshape(-1), ignore_index=50257)

        self.log('Training loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        logits = self.model(X, y)
        logits[:, :, 50257:] = -1e4
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]), target=y[:, :-1].reshape(-1), ignore_index=50257)

        self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def on_test_epoch_start(self):
        self.all_text = []
        self.predicted_text = []

    def test_step(self, batch, batch_idx):
        if batch_idx == 0:
            return
        X, y = batch
        out = self.model.generate(X)
        # out = zero_after_x(out, self.model.eot.to(X.device))
        pred = self.model.tokenizer.batch_decode(out, skip_special_tokens=False)
        gt = self.model.tokenizer.batch_decode(y, skip_special_tokens=False)

        print(pred)
        print('GAP')
        print(gt)

        for p, g in zip(pred, gt):
            score = self.rouge(p, g)
            print(score)

        self.log_dict(score, on_step=True, on_epoch=True, logger=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer


from lightning.pytorch.loggers import WandbLogger

if __name__ == '__main__':
    train = False

    torch.set_float32_matmul_precision('medium')
    dl_train, dl_val, dl_test = utils.import_data(bs=24, fraction=1)
    # gpt_model = PromptTuningModel(num_prompts=24)
    if train:
        gpt_model = LoraModel(dim=16)
    else:
        gpt_model = torch.load('./model1.bin')

    # gpt_model = torch.compile(gpt_model)
    model = LitModelPromptTuning(
        model=gpt_model,
        lr=1e-3,
        temperature=0.9,
        epoch=5,
        type_model='lora',
        dimension=16,
    )
    print('Training')

    logger = WandbLogger(project='Anlp-3')
    trainer = L.Trainer(
        accelerator='gpu',
        # limit_train_batches=1,
        # strategy=pl.strategies.DDPStrategy(find_unused_parameters=True),
        devices=1,
        default_root_dir='./logs/',  # TensorBoard can be used to visualise the logs
        num_nodes=1,
        num_sanity_val_steps=1,      # runs a validation step before starting training
        precision='bf16-mixed',      # half precision to reduce memory usage
        max_epochs=5,
        check_val_every_n_epoch=1,   # run validation every epoch
        log_every_n_steps=20,
        logger=logger,
        # detect_anomaly=True,
    )

    if train:
        trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
        trainer.test(model, dataloaders=dl_test)
        torch.save(model.model, './model2.bin')
    else:
        trainer.test(model, dataloaders=dl_test)
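
The LoraConfig in main2.py targets only GPT-2's fused attention projection c_attn, so after get_peft_model only the rank-16 adapter matrices should require gradients. A minimal sketch for verifying that, assuming peft and transformers are installed; print_trainable_parameters is peft's own helper:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")
cfg = LoraConfig(r=16, lora_alpha=32, target_modules=["c_attn"],
                 lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
peft_model = get_peft_model(base, cfg)
peft_model.print_trainable_parameters()  # the adapters are a small fraction of GPT-2's ~124M params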
model0.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5411f279c6809268001997e4d5ea65eb050828c703d82a21066eb5f2370f01e3
size 512094683
model1.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:38409a7f9c3348446f062d5b616f603770976c94e0c789f7baebf28005c87674
size 511946721
model2.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e74693de89775360e55273ca457bfc44c1396070076530fb272bf5a355aa2a2c
size 514335129
newspaper-text-summarization-cnn-dailymail.zip
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8f9e0cad39333d1c9f8902be6d846000ef1ccd5e994ad3c42f53336270ab8611
size 527738644
utils.py
ADDED
@@ -0,0 +1,124 @@
from math import inf
# from utils import *
import torch
import torch.nn as nn
import numpy as np
import torch.utils
import torch.utils.data

from torch.utils.data import DataLoader, Dataset
# from utils import MyDataset, custom_collate
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
import wandb
import torch.nn.functional as F

import einops
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import GPT2TokenizerFast
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import re


class CNNDataset(Dataset):
    def __init__(self, df, max_length=1000, max_len=21000, test_ds=False):
        super().__init__()
        self.df = df
        self.max_len = max_len
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.max_length = max_length
        self.test_ds = test_ds
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
        self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]
        print(len(self.tokenizer), self.start, "Pad")
        for index in range(max_len):
            x, y = self.df['article'][index], self.df['highlights'][index]
            x, y = re.sub(r'[\t\n\r]', ' ', x), re.sub(r'[\t\n\r]', ' ', y)
            y = self.tokenizer(y, return_tensors="pt", max_length=256, truncation=True).input_ids[0]
            x = self.tokenizer(x, return_tensors="pt", max_length=self.max_length - max(y.shape[0], 256 + 24), truncation=True).input_ids[0]
            self.df.loc[index, 'article'], self.df.loc[index, 'highlights'] = x, y

    def __len__(self):
        return self.max_len

    def __getitem__(self, index):
        x, y = self.df['article'][index], self.df['highlights'][index]

        # Check whether a middle self.eot is needed
        if self.test_ds:
            return torch.cat([self.eot, x, self.start]), torch.cat([y, self.eot])
        x = torch.cat([self.eot, x, self.start, y, self.eot])
        y = torch.cat([y, self.eot])

        y_final = torch.ones(x.shape[0], dtype=torch.long)
        y_final[-y.shape[0] - 1:-1] = y
        y_final[:-y.shape[0] - 1] = self.pad
        return x, y_final


def properly_pad(context):
    lengths = []
    for i in context:
        lengths.append(i.shape[0])
    lengths = torch.tensor(lengths)

    ind = torch.argsort(lengths, descending=True)
    lengths = lengths[ind]

    sorted_tensors = [context[i] for i in ind]

    context = sorted_tensors
    context = pad_sequence(sequences=context, batch_first=True, padding_value=50257)

    return context


def custom_collate(batch):
    context, target = [], []
    for a, b in batch:
        context.append(a)
        target.append(b)

    context, target = properly_pad(context), properly_pad(target)

    return context, target


def import_data(bs=4, fraction=0.1):
    df_train = pd.read_csv('./cnn_dailymail/train.csv')
    df_val = pd.read_csv('./cnn_dailymail/validation.csv')
    df_test = pd.read_csv('./cnn_dailymail/test.csv')

    print('Loaded data')

    df_train, df_val, df_test = CNNDataset(df_train, max_len=int(21000 * fraction)), CNNDataset(df_val, max_len=int(fraction * 6000)), CNNDataset(df_test, max_len=int(fraction * 300), test_ds=True)

    df_train = DataLoader(df_train, batch_size=bs, num_workers=7, collate_fn=custom_collate)
    df_test = DataLoader(df_test, batch_size=1, num_workers=7, collate_fn=custom_collate)
    df_val = DataLoader(df_val, batch_size=bs, num_workers=7, collate_fn=custom_collate)

    return df_train, df_val, df_test


if __name__ == '__main__':
    tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")

    tokenizer.add_special_tokens({'cls_token': '[START]'})

    eot = tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
    pad = tokenizer("[PAD]", return_tensors="pt").input_ids[0]
    start = tokenizer("[START]", return_tensors="pt").input_ids[0]

    print(tokenizer.decode([1, 2, 50256]))
    print(tokenizer.decode([1, 2, 50257]))
    print(tokenizer('[START]'))
    # dl_train, dl_val, dl_test = import_data()
    # for x, y in dl_train:
    #     print(x.shape, y.shape)
    #     break
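
properly_pad in utils.py sorts a batch by descending length and right-pads every sequence to the longest one with id 50257, the index of the added [PAD] token; the same index is later passed as ignore_index to the loss. A minimal sketch of the padding step with made-up token ids:

import torch
from torch.nn.utils.rnn import pad_sequence

batch = [torch.tensor([11, 22, 33]), torch.tensor([44, 55])]
print(pad_sequence(batch, batch_first=True, padding_value=50257))
# tensor([[   11,    22,    33],
#         [   44,    55, 50257]])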
wandb/debug-internal.log
ADDED
@@ -0,0 +1,22 @@
{"time":"2024-10-29T19:35:07.322742046+04:00","level":"INFO","msg":"using version","core version":"0.18.1"}
{"time":"2024-10-29T19:35:07.322754356+04:00","level":"INFO","msg":"created symlink","path":"wandb/run-20241029_193507-ujpie5pz/logs/debug-core.log"}
{"time":"2024-10-29T19:35:07.323074973+04:00","level":"INFO","msg":"using version","core version":"0.18.1"}
{"time":"2024-10-29T19:35:07.323078902+04:00","level":"INFO","msg":"created symlink","path":"wandb/run-20241029_193507-ujpie5pz/logs/debug-core.log"}
{"time":"2024-10-29T19:35:07.327151125+04:00","level":"INFO","msg":"created new stream","id":"ujpie5pz"}
{"time":"2024-10-29T19:35:07.327170485+04:00","level":"INFO","msg":"stream: started","id":"ujpie5pz"}
{"time":"2024-10-29T19:35:07.327189255+04:00","level":"INFO","msg":"handler: started","stream_id":{"value":"ujpie5pz"}}
{"time":"2024-10-29T19:35:07.327206894+04:00","level":"INFO","msg":"sender: started","stream_id":{"value":"ujpie5pz"}}
{"time":"2024-10-29T19:35:07.327237104+04:00","level":"INFO","msg":"writer: Do: started","stream_id":{"value":"ujpie5pz"}}
{"time":"2024-10-29T19:35:07.98503491+04:00","level":"INFO","msg":"wandb-core","!BADKEY":null}
{"time":"2024-10-29T19:35:07.98717852+04:00","level":"INFO","msg":"Starting system monitor"}
{"time":"2024-10-29T19:35:07.988807925+04:00","level":"ERROR","msg":"git repo not found","error":"repository does not exist"}
{"time":"2024-10-29T19:37:46.533074708+04:00","level":"INFO","msg":"stream: closing","id":"ujpie5pz"}
{"time":"2024-10-29T19:37:46.533114998+04:00","level":"INFO","msg":"Stopping system monitor"}
{"time":"2024-10-29T19:37:46.534237927+04:00","level":"INFO","msg":"Stopped system monitor"}
{"time":"2024-10-29T19:37:46.799687608+04:00","level":"WARN","msg":"No job ingredients found, not creating job artifact"}
{"time":"2024-10-29T19:37:46.799694428+04:00","level":"WARN","msg":"No source type found, not creating job artifact"}
{"time":"2024-10-29T19:37:46.799698199+04:00","level":"INFO","msg":"sender: sendDefer: no job artifact to save"}
{"time":"2024-10-29T19:37:48.809218785+04:00","level":"INFO","msg":"handler: closed","stream_id":{"value":"ujpie5pz"}}
{"time":"2024-10-29T19:37:48.809270345+04:00","level":"INFO","msg":"sender: closed","stream_id":{"value":"ujpie5pz"}}
{"time":"2024-10-29T19:37:48.809270115+04:00","level":"INFO","msg":"writer: Close: closed","stream_id":{"value":"ujpie5pz"}}
{"time":"2024-10-29T19:37:48.809871399+04:00","level":"INFO","msg":"stream: closed","id":"ujpie5pz"}
wandb/debug.log
ADDED
@@ -0,0 +1,27 @@
2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_setup.py:_flush():77] Current SDK version is 0.18.1
2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_setup.py:_flush():77] Configure stats pid to 1599827
2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_setup.py:_flush():77] Loading settings from /home/siddharth.tourani/.config/wandb/settings
2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_setup.py:_flush():77] Loading settings from /home/siddharth.tourani/Kyrylo/nlp/wandb/settings
2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_setup.py:_flush():77] Loading settings from environment variables: {}
2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_setup.py:_flush():77] Applying setup settings: {'mode': None, '_disable_service': None}
2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_setup.py:_flush():77] Inferring run settings from compute environment: {'program_relpath': 'main2.py', 'program_abspath': '/home/siddharth.tourani/Kyrylo/nlp/main2.py', 'program': '/home/siddharth.tourani/Kyrylo/nlp/main2.py'}
2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_setup.py:_flush():77] Applying login settings: {}
2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_init.py:_log_setup():532] Logging user logs to ./wandb/run-20241029_193507-ujpie5pz/logs/debug.log
2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_init.py:_log_setup():533] Logging internal logs to ./wandb/run-20241029_193507-ujpie5pz/logs/debug-internal.log
2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_init.py:init():616] calling init triggers
2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_init.py:init():623] wandb.init called with sweep_config: {}
config: {}
2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_init.py:init():666] starting backend
2024-10-29 19:35:07,318 INFO MainThread:1599827 [wandb_init.py:init():670] setting up manager
2024-10-29 19:35:07,319 INFO MainThread:1599827 [backend.py:_multiprocessing_setup():105] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
2024-10-29 19:35:07,320 INFO MainThread:1599827 [wandb_init.py:init():678] backend started and connected
2024-10-29 19:35:07,322 INFO MainThread:1599827 [wandb_init.py:init():773] updated telemetry
2024-10-29 19:35:07,323 INFO MainThread:1599827 [wandb_init.py:init():806] communicating run to backend with 90.0 second timeout
2024-10-29 19:35:07,981 INFO MainThread:1599827 [wandb_init.py:init():857] starting run threads in backend
2024-10-29 19:35:08,154 INFO MainThread:1599827 [wandb_run.py:_console_start():2459] atexit reg
2024-10-29 19:35:08,154 INFO MainThread:1599827 [wandb_run.py:_redirect():2307] redirect: wrap_raw
2024-10-29 19:35:08,154 INFO MainThread:1599827 [wandb_run.py:_redirect():2372] Wrapping output streams.
2024-10-29 19:35:08,154 INFO MainThread:1599827 [wandb_run.py:_redirect():2397] Redirects installed.
2024-10-29 19:35:08,156 INFO MainThread:1599827 [wandb_init.py:init():900] run started, returning control to user process
2024-10-29 19:35:08,447 INFO MainThread:1599827 [wandb_run.py:_config_callback():1388] config_cb None None {'temperature': 0.9, 'epoch': 5, 'lr': 0.001, 'type_model': 'lora', 'dimension': 16}
2024-10-29 19:37:46,533 WARNING MsgRouterThr:1599827 [router.py:message_loop():77] message_loop has been closed
wandb/run-20241028_085547-mga40p7t/files/code/main.py
ADDED
@@ -0,0 +1,124 @@
from math import inf
# from utils import *
import torch
import torch.nn as nn
import numpy as np
import torch.utils
import torch.utils.data
# from utils import MyDataset, custom_collate
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
import wandb
import torch.nn.functional as F

import einops
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast

np.random.seed(123)
torch.manual_seed(123)
torch.cuda.random.manual_seed(123)

import lightning as L
import utils


class PromptTuningModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.requires_grad_(False)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]

        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer), pad_to_multiple_of=128)

        # Initialise the learnable prompt from repeated embeddings of 'summarise'
        tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
        token_embedding = self.model.transformer.wte(tmp[0])
        self.token_embedding = token_embedding
        for _ in range(num_prompts // 3 - 1):
            self.token_embedding = torch.cat([self.token_embedding, token_embedding])

        data = torch.zeros(num_prompts, 768) + self.token_embedding[:]
        self.learnable_prompt = nn.Parameter(data, requires_grad=True)

    # @torch.compile
    def forward(self, X):
        self.learnable_prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)  # b s d

        embeddings = torch.cat([embeddings, self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1)], dim=1)
        mask = torch.cat([torch.ones([X.shape[0], self.num_prompts], dtype=torch.long).to(X), torch.where(X != self.pad.to(X.device), 1, 0)], dim=1)
        logits = self.model(inputs_embeds=embeddings, attention_mask=mask).logits[:, self.num_prompts:].swapaxes(1, 2)
        return logits


class LitModelPromptTuning(L.LightningModule):
    def __init__(self, model, lr=1e-4):
        super().__init__()
        self.model = model
        self.lr = lr

        self.save_hyperparameters(ignore=['model'])

    def training_step(self, batch, batch_idx):
        X, y = batch
        logits = self.model(X)
        loss = F.cross_entropy(logits, target=y, ignore_index=50257)

        self.log('Training loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        logits = self.model(X)
        loss = F.cross_entropy(logits, target=y, ignore_index=50257)

        self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer


from lightning.pytorch.loggers import WandbLogger

if __name__ == '__main__':
    torch.set_float32_matmul_precision('medium')
    dl_train, dl_val, dl_test = utils.import_data(bs=5, fraction=0.1)
    gpt_model = PromptTuningModel()
    # gpt_model = torch.compile(gpt_model)
    model = LitModelPromptTuning(model=gpt_model)
    print('Training')

    logger = WandbLogger(project='Anlp-3')
    trainer = L.Trainer(
        accelerator='gpu',
        # strategy='auto',
        # strategy=pl.strategies.DDPStrategy(find_unused_parameters=True),
        devices=[3],
        default_root_dir='./logs/',  # TensorBoard can be used to visualise the logs
        num_nodes=1,
        num_sanity_val_steps=1,      # runs a validation step before starting training
        precision='16-mixed',        # half precision to reduce memory usage
        max_epochs=10,
        check_val_every_n_epoch=1,   # run validation every epoch
        log_every_n_steps=20,
        logger=logger,
        # detect_anomaly=True,
    )

    trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
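
This early snapshot appends the learnable prompt after the token embeddings along the sequence axis. A minimal shape check of that concatenation, using a random stand-in for GPT-2's transformer.wte (the sizes here are illustrative only):

import torch
import torch.nn as nn

num_prompts, d_model, batch, seq = 6, 768, 2, 10
wte = nn.Embedding(50258, d_model)  # stand-in for model.transformer.wte
prompt = nn.Parameter(torch.zeros(num_prompts, d_model))

X = torch.randint(0, 50258, (batch, seq))
emb = torch.cat([wte(X), prompt[None].repeat(batch, 1, 1)], dim=1)
print(emb.shape)  # torch.Size([2, 16, 768])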
wandb/run-20241028_085547-mga40p7t/files/config.yaml
ADDED
@@ -0,0 +1,69 @@
_wandb:
  value:
    cli_version: 0.18.1
    code_path: code/main.py
    m:
    - "1": trainer/global_step
      "6":
      - 3
      "7": []
    - "1": Training loss_step
      "5": 1
      "6":
      - 1
      - 3
      "7": []
    - "1": epoch
      "5": 1
      "6":
      - 1
      - 3
      "7": []
    - "1": Validation loss_step
      "5": 1
      "6":
      - 1
      - 3
      "7": []
    - "1": Validation loss_epoch
      "5": 1
      "6":
      - 1
      - 3
      "7": []
    - "1": Training loss_epoch
      "5": 1
      "6":
      - 1
      - 3
      "7": []
    python_version: 3.10.13
    t:
      "1":
      - 1
      - 11
      - 49
      - 55
      - 71
      - 106
      "2":
      - 1
      - 11
      - 49
      - 55
      - 71
      - 106
      "3":
      - 7
      - 23
      - 55
      - 66
      "4": 3.10.13
      "5": 0.18.1
      "6": 4.44.2
      "8":
      - 5
      "12": 0.18.1
      "13": linux-x86_64
lr:
  value: 0.0001
wandb/run-20241028_085547-mga40p7t/files/output.log
ADDED
@@ -0,0 +1,16 @@
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name  | Type              | Params | Mode
----------------------------------------------------
0 | model | PromptTuningModel | 124 M  | train
----------------------------------------------------
4.6 K     Trainable params
124 M     Non-trainable params
124 M     Total params
497.922   Total estimated model params size (MB)
1         Modules in train mode
164       Modules in eval mode
Epoch 1:   3%|██▌       | 11/420 [00:02<01:47, 3.79it/s, v_num=0p7t]


Detected KeyboardInterrupt, attempting graceful shutdown ...
wandb/run-20241028_085547-mga40p7t/files/requirements.txt
ADDED
@@ -0,0 +1,178 @@
jiter==0.5.0
anyio==4.6.0
interegular==0.3.3
jaxlib==0.4.34
jsonschema==4.23.0
typing_extensions==4.12.2
httpcore==1.0.5
prometheus_client==0.21.0
openai==1.51.0
multidict==6.1.0
six==1.16.0
nvidia-nccl-cu12==2.20.5
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
watchfiles==0.24.0
tqdm==4.66.5
yarl==1.11.1
cffi==1.17.1
vllm==0.6.1.post2
bleach==6.1.0
kaggle==1.6.17
pydantic_core==2.23.4
lightning-utilities==0.11.7
sentry-sdk==2.14.0
torch==2.4.0
aiohappyeyeballs==2.4.0
diffusers==0.15.0
GitPython==3.1.43
attrs==24.2.0
importlib_metadata==8.5.0
transformers==4.44.2
pillow==10.4.0
sounddevice==0.5.1
gguf==0.9.1
python-dotenv==1.0.1
async-timeout==4.0.3
dspy-ai==2.5.3
numpy==1.26.4
nvidia-nvjitlink-cu12==12.6.68
uvicorn==0.30.6
kiwisolver==1.4.7
partial-json-parser==0.2.1.1.post4
pyparsing==3.1.4
lightning==2.4.0
structlog==24.4.0
nvidia-curand-cu12==10.3.2.106
setuptools==65.5.0
webencodings==0.5.1
nvidia-nvtx-cu12==12.1.105
sniffio==1.3.1
MarkupSafe==2.1.5
vllm-flash-attn==2.6.1
urllib3==2.2.3
requests==2.32.3
pycountry==24.6.1
ujson==5.10.0
matplotlib==3.9.2
pydantic==2.9.2
torchvision==0.19.0
numba==0.60.0
optuna==4.0.0
opt_einsum==3.4.0
joblib==1.4.2
msgpack==1.1.0
smmap==5.0.1
filelock==3.16.1
opencv-contrib-python==4.10.0.84
faiss-gpu==1.7.2
prometheus-fastapi-instrumentator==7.0.0
rpds-py==0.20.0
psutil==6.0.0
colorlog==6.8.2
nvidia-cufft-cu12==11.0.2.54
SQLAlchemy==2.0.35
llvmlite==0.43.0
packaging==24.1
exceptiongroup==1.2.2
dill==0.3.8
ml_dtypes==0.5.0
pyairports==2.1.1
scikit-learn==1.5.2
prettytable==3.11.0
protobuf==4.25.5
charset-normalizer==3.3.2
torchmetrics==1.4.2
text-unidecode==1.3
httpx==0.27.2
sympy==1.13.3
msgspec==0.18.6
wandb==0.18.1
backoff==2.2.1
sentencepiece==0.2.0
aiohttp==3.10.5
distro==1.9.0
lark==1.2.2
pyarrow==17.0.0
Mako==1.3.5
regex==2024.9.11
safetensors==0.4.5
aiosignal==1.3.1
jsonschema-specifications==2023.12.1
cloudpickle==3.0.0
einops==0.8.0
ray==2.36.1
fire==0.7.0
pyzmq==26.2.0
pycparser==2.22
platformdirs==4.3.6
click==8.1.7
fastapi==0.115.0
ftfy==6.3.0
torchtext==0.18.0
lm-format-enforcer==0.10.6
fsspec==2024.6.1
tzdata==2024.2
starlette==0.38.6
cycler==0.12.1
py-cpuinfo==9.0.0
h11==0.14.0
huggingface-hub==0.25.1
nvidia-cusparse-cu12==12.1.0.106
nvidia-ml-py==12.560.30
certifi==2024.8.30
httptools==0.6.1
jax==0.4.34
PyYAML==6.0.2
xxhash==3.5.0
idna==3.10
xformers==0.0.27.post2
mistral_common==1.4.3
fonttools==4.54.0
pip==23.0.1
accelerate==0.34.2
mediapipe==0.10.15
pytorch-lightning==2.4.0
ollama==0.3.3
Jinja2==3.1.4
multiprocess==0.70.16
opencv-python==4.10.0.84
termcolor==2.5.0
python-dateutil==2.9.0.post0
contourpy==1.3.0
websockets==13.1
frozenlist==1.4.1
pandas==2.2.3
networkx==3.3
diskcache==5.6.3
nvidia-cusolver-cu12==11.4.5.107
flatbuffers==24.3.25
mpmath==1.3.0
setproctitle==1.3.3
tokenizers==0.19.1
scipy==1.14.1
outlines==0.0.46
annotated-types==0.7.0
docker-pycreds==0.4.0
magicattr==0.1.6
wcwidth==0.2.13
pytorch-metric-learning==2.6.0
datasets==3.0.0
gitdb==4.0.11
lora-diffusion==0.1.7
referencing==0.35.1
python-slugify==8.0.4
zipp==3.20.2
triton==3.0.0
absl-py==2.1.0
threadpoolctl==3.5.0
uvloop==0.20.0
tiktoken==0.7.0
pytz==2024.2
nest-asyncio==1.6.0
nvidia-cublas-cu12==12.1.3.1
litellm==1.48.12
nvidia-cuda-nvrtc-cu12==12.1.105
greenlet==3.1.1
alembic==1.13.3
wandb/run-20241028_085547-mga40p7t/files/wandb-metadata.json
ADDED
@@ -0,0 +1,60 @@
1 |
+
{
|
2 |
+
"os": "Linux-5.15.161-ql-generic-13.0-14-x86_64-with-glibc2.35",
|
3 |
+
"python": "3.10.13",
|
4 |
+
"startedAt": "2024-10-28T04:55:47.033649Z",
|
5 |
+
"program": "/home/siddharth.tourani/Kyrylo/nlp/main.py",
|
6 |
+
"codePath": "main.py",
|
7 |
+
"email": "kyrylo.shyvam@students.iiit.ac.in",
|
8 |
+
"root": ".",
|
9 |
+
"host": "gpu-08",
|
10 |
+
"username": "siddharth.tourani",
|
11 |
+
"executable": "/home/siddharth.tourani/Minimal/bin/python3",
|
12 |
+
"codePathLocal": "main.py",
|
13 |
+
"cpu_count": 128,
|
14 |
+
"cpu_count_logical": 256,
|
15 |
+
"gpu": "[NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB]",
|
16 |
+
"gpu_count": 4,
|
17 |
+
"disk": {
|
18 |
+
"/": {
|
19 |
+
"total": "1073741824",
|
20 |
+
"used": "21049344"
|
21 |
+
}
|
22 |
+
},
|
23 |
+
"memory": {
|
24 |
+
"total": "270256893952"
|
25 |
+
},
|
26 |
+
"cpu": {
|
27 |
+
"count": 128,
|
28 |
+
"countLogical": 256
|
29 |
+
},
|
30 |
+
"gpu_nvidia": [
|
31 |
+
{
|
32 |
+
"name": "NVIDIA A100-SXM4-40GB",
|
33 |
+
"memoryTotal": "42949672960",
|
34 |
+
"cudaCores": 6912,
|
35 |
+
"architecture": "Ampere"
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"name": "NVIDIA A100-SXM4-40GB",
|
39 |
+
"memoryTotal": "42949672960",
|
40 |
+
"cudaCores": 6912,
|
41 |
+
"architecture": "Ampere"
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"name": "NVIDIA A100-SXM4-40GB",
|
45 |
+
"memoryTotal": "42949672960",
|
46 |
+
"cudaCores": 6912,
|
47 |
+
"architecture": "Ampere"
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"name": "NVIDIA A100-SXM4-40GB",
|
51 |
+
"memoryTotal": "42949672960",
|
52 |
+
"cudaCores": 6912,
|
53 |
+
"architecture": "Ampere"
|
54 |
+
}
|
55 |
+
],
|
56 |
+
"slurm": {
|
57 |
+
"job_id": "68153"
|
58 |
+
},
|
59 |
+
"cudaVersion": "12.4"
|
60 |
+
}
|
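The metadata above records the run host: four 40 GB A100s, 128 physical cores, CUDA 12.4, under Slurm job 68153. As an illustration only (this is not code from the repository, and it assumes a CUDA-enabled PyTorch build), the same GPU fields can be read from Python like so:

import torch

# Prints one line per visible GPU, mirroring the "gpu_nvidia" entries above.
# On a CPU-only host the loop body never runs.
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        print(props.name, props.total_memory)  # e.g. NVIDIA A100-SXM4-40GB 42949672960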
wandb/run-20241028_085547-mga40p7t/files/wandb-summary.json
ADDED
@@ -0,0 +1 @@
{"epoch":0,"Training loss_epoch":107.80526733398438,"_step":142,"_runtime":102.656591521,"Validation loss_epoch":107.9881820678711,"Training loss_step":108.02771759033203,"_timestamp":1.730091449689934e+09,"trainer/global_step":419,"Validation loss_step":108.47615051269531,"_wandb":{"runtime":105}}
wandb/run-20241028_085547-mga40p7t/logs/debug-core.log
ADDED
@@ -0,0 +1,14 @@
{"time":"2024-10-28T08:55:46.356024757+04:00","level":"INFO","msg":"started logging, with flags","port-filename":"/tmp/tmptnuwuzf_/port-1239188.txt","pid":1239188,"debug":false,"disable-analytics":false}
{"time":"2024-10-28T08:55:46.356050947+04:00","level":"INFO","msg":"FeatureState","shutdownOnParentExitEnabled":false}
{"time":"2024-10-28T08:55:46.35809675+04:00","level":"INFO","msg":"Will exit if parent process dies.","ppid":1239188}
{"time":"2024-10-28T08:55:46.35809731+04:00","level":"INFO","msg":"server is running","addr":{"IP":"127.0.0.1","Port":36381,"Zone":""}}
{"time":"2024-10-28T08:55:46.482841457+04:00","level":"INFO","msg":"created new connection","id":"127.0.0.1:59774"}
{"time":"2024-10-28T08:55:47.034493884+04:00","level":"INFO","msg":"connection init received","streamId":"mga40p7t","id":"127.0.0.1:59774"}
{"time":"2024-10-28T08:55:47.035343787+04:00","level":"ERROR","msg":"error creating symlink","error":"symlink /home/siddharth.tourani/.cache/wandb/logs/core-debug-20241028_085546.log wandb/run-20241028_085547-mga40p7t/logs/debug-core.log: file exists"}
{"time":"2024-10-28T08:55:47.044292544+04:00","level":"INFO","msg":"connection init completed","streamId":"mga40p7t","id":"127.0.0.1:59774"}
{"time":"2024-10-28T08:57:32.990052624+04:00","level":"INFO","msg":"connection: teardown","id":"127.0.0.1:59774"}
{"time":"2024-10-28T08:57:32.990356191+04:00","level":"INFO","msg":"server is shutting down"}
{"time":"2024-10-28T08:57:32.990407101+04:00","level":"INFO","msg":"closed connection","id":"127.0.0.1:59774"}
{"time":"2024-10-28T08:57:33.264727032+04:00","level":"ERROR","msg":"error flushing writer","err":"write tcp 127.0.0.1:36381->127.0.0.1:59774: use of closed network connection","id":"127.0.0.1:59774"}
{"time":"2024-10-28T08:57:34.392709437+04:00","level":"INFO","msg":"connection closed","id":"127.0.0.1:59774"}
{"time":"2024-10-28T08:57:34.392722266+04:00","level":"INFO","msg":"server is closed"}
wandb/run-20241028_085547-mga40p7t/logs/debug-internal.log
ADDED
@@ -0,0 +1,22 @@
{"time":"2024-10-28T08:55:47.035188748+04:00","level":"INFO","msg":"using version","core version":"0.18.1"}
{"time":"2024-10-28T08:55:47.035200088+04:00","level":"INFO","msg":"created symlink","path":"wandb/run-20241028_085547-mga40p7t/logs/debug-core.log"}
{"time":"2024-10-28T08:55:47.035571516+04:00","level":"INFO","msg":"using version","core version":"0.18.1"}
{"time":"2024-10-28T08:55:47.035578935+04:00","level":"INFO","msg":"created symlink","path":"wandb/run-20241028_085547-mga40p7t/logs/debug-core.log"}
{"time":"2024-10-28T08:55:47.044272764+04:00","level":"INFO","msg":"created new stream","id":"mga40p7t"}
{"time":"2024-10-28T08:55:47.044289244+04:00","level":"INFO","msg":"stream: started","id":"mga40p7t"}
{"time":"2024-10-28T08:55:47.044306184+04:00","level":"INFO","msg":"handler: started","stream_id":{"value":"mga40p7t"}}
{"time":"2024-10-28T08:55:47.044329513+04:00","level":"INFO","msg":"sender: started","stream_id":{"value":"mga40p7t"}}
{"time":"2024-10-28T08:55:47.044370663+04:00","level":"INFO","msg":"writer: Do: started","stream_id":{"value":"mga40p7t"}}
{"time":"2024-10-28T08:55:47.744002406+04:00","level":"INFO","msg":"wandb-core","!BADKEY":null}
{"time":"2024-10-28T08:55:47.747442758+04:00","level":"INFO","msg":"Starting system monitor"}
{"time":"2024-10-28T08:55:47.748927455+04:00","level":"ERROR","msg":"git repo not found","error":"repository does not exist"}
{"time":"2024-10-28T08:57:32.990331552+04:00","level":"INFO","msg":"stream: closing","id":"mga40p7t"}
{"time":"2024-10-28T08:57:32.990407231+04:00","level":"INFO","msg":"Stopping system monitor"}
{"time":"2024-10-28T08:57:32.992264416+04:00","level":"INFO","msg":"Stopped system monitor"}
{"time":"2024-10-28T08:57:33.519506024+04:00","level":"WARN","msg":"No job ingredients found, not creating job artifact"}
{"time":"2024-10-28T08:57:33.519518634+04:00","level":"WARN","msg":"No source type found, not creating job artifact"}
{"time":"2024-10-28T08:57:33.519521734+04:00","level":"INFO","msg":"sender: sendDefer: no job artifact to save"}
{"time":"2024-10-28T08:57:34.391853104+04:00","level":"INFO","msg":"handler: closed","stream_id":{"value":"mga40p7t"}}
{"time":"2024-10-28T08:57:34.391906283+04:00","level":"INFO","msg":"writer: Close: closed","stream_id":{"value":"mga40p7t"}}
{"time":"2024-10-28T08:57:34.391911083+04:00","level":"INFO","msg":"sender: closed","stream_id":{"value":"mga40p7t"}}
{"time":"2024-10-28T08:57:34.392637247+04:00","level":"INFO","msg":"stream: closed","id":"mga40p7t"}
wandb/run-20241028_085547-mga40p7t/logs/debug.log
ADDED
@@ -0,0 +1,27 @@
2024-10-28 08:55:47,029 INFO MainThread:1239188 [wandb_setup.py:_flush():77] Current SDK version is 0.18.1
2024-10-28 08:55:47,029 INFO MainThread:1239188 [wandb_setup.py:_flush():77] Configure stats pid to 1239188
2024-10-28 08:55:47,029 INFO MainThread:1239188 [wandb_setup.py:_flush():77] Loading settings from /home/siddharth.tourani/.config/wandb/settings
2024-10-28 08:55:47,029 INFO MainThread:1239188 [wandb_setup.py:_flush():77] Loading settings from /home/siddharth.tourani/Kyrylo/nlp/wandb/settings
2024-10-28 08:55:47,029 INFO MainThread:1239188 [wandb_setup.py:_flush():77] Loading settings from environment variables: {}
2024-10-28 08:55:47,029 INFO MainThread:1239188 [wandb_setup.py:_flush():77] Applying setup settings: {'mode': None, '_disable_service': None}
2024-10-28 08:55:47,029 INFO MainThread:1239188 [wandb_setup.py:_flush():77] Inferring run settings from compute environment: {'program_relpath': 'main.py', 'program_abspath': '/home/siddharth.tourani/Kyrylo/nlp/main.py', 'program': '/home/siddharth.tourani/Kyrylo/nlp/main.py'}
2024-10-28 08:55:47,029 INFO MainThread:1239188 [wandb_setup.py:_flush():77] Applying login settings: {}
2024-10-28 08:55:47,029 INFO MainThread:1239188 [wandb_init.py:_log_setup():532] Logging user logs to ./wandb/run-20241028_085547-mga40p7t/logs/debug.log
2024-10-28 08:55:47,029 INFO MainThread:1239188 [wandb_init.py:_log_setup():533] Logging internal logs to ./wandb/run-20241028_085547-mga40p7t/logs/debug-internal.log
2024-10-28 08:55:47,030 INFO MainThread:1239188 [wandb_init.py:init():616] calling init triggers
2024-10-28 08:55:47,030 INFO MainThread:1239188 [wandb_init.py:init():623] wandb.init called with sweep_config: {}
config: {}
2024-10-28 08:55:47,030 INFO MainThread:1239188 [wandb_init.py:init():666] starting backend
2024-10-28 08:55:47,030 INFO MainThread:1239188 [wandb_init.py:init():670] setting up manager
2024-10-28 08:55:47,030 INFO MainThread:1239188 [backend.py:_multiprocessing_setup():105] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
2024-10-28 08:55:47,033 INFO MainThread:1239188 [wandb_init.py:init():678] backend started and connected
2024-10-28 08:55:47,036 INFO MainThread:1239188 [wandb_init.py:init():773] updated telemetry
2024-10-28 08:55:47,036 INFO MainThread:1239188 [wandb_init.py:init():806] communicating run to backend with 90.0 second timeout
2024-10-28 08:55:47,740 INFO MainThread:1239188 [wandb_init.py:init():857] starting run threads in backend
2024-10-28 08:55:47,953 INFO MainThread:1239188 [wandb_run.py:_console_start():2459] atexit reg
2024-10-28 08:55:47,953 INFO MainThread:1239188 [wandb_run.py:_redirect():2307] redirect: wrap_raw
2024-10-28 08:55:47,953 INFO MainThread:1239188 [wandb_run.py:_redirect():2372] Wrapping output streams.
2024-10-28 08:55:47,953 INFO MainThread:1239188 [wandb_run.py:_redirect():2397] Redirects installed.
2024-10-28 08:55:47,957 INFO MainThread:1239188 [wandb_init.py:init():900] run started, returning control to user process
2024-10-28 08:55:48,061 INFO MainThread:1239188 [wandb_run.py:_config_callback():1388] config_cb None None {'lr': 0.0001}
2024-10-28 08:57:32,990 WARNING MsgRouterThr:1239188 [router.py:message_loop():77] message_loop has been closed
wandb/run-20241028_085547-mga40p7t/run-mga40p7t.wandb
ADDED
Binary file (500 kB)
wandb/run-20241028_085806-owcrwbil/files/code/main.py
ADDED
@@ -0,0 +1,124 @@
from math import inf
# from utils import *
import torch
import torch.nn as nn
import numpy as np
import torch.utils
import torch.utils.data
# from utils import MyDataset, custom_collate
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
import wandb
import torch.nn.functional as F

import einops
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast

np.random.seed(123)
torch.manual_seed(123)
torch.cuda.random.manual_seed(123)

import lightning as L
import utils


class PromptTuningModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.requires_grad_(False)  # freeze the backbone; only the soft prompt trains
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]

        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer), pad_to_multiple_of=128)

        # Initialise the soft prompt from the embedding of 'summarise',
        # tiled until it covers num_prompts positions.
        tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
        token_embedding = self.model.transformer.wte(tmp[0])
        self.token_embedding = token_embedding
        for _ in range(num_prompts // 3 - 1):
            self.token_embedding = torch.cat([self.token_embedding, token_embedding])

        data = torch.zeros(num_prompts, 768) + self.token_embedding[:]
        self.learnable_prompt = nn.Parameter(data, requires_grad=True)

    # @torch.compile
    def forward(self, X):
        # Use a local alias instead of reassigning the Parameter: assigning the
        # plain tensor returned by .to() back to a registered Parameter attribute
        # raises a TypeError whenever the move actually makes a copy.
        prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)  # b s d

        embeddings = torch.cat([embeddings, prompt[None, :, :].repeat(X.shape[0], 1, 1)], dim=1)
        mask = torch.cat([torch.ones([X.shape[0], self.num_prompts], dtype=torch.long).to(X), torch.where(X != self.pad.to(X.device), 1, 0)], dim=1)
        logits = self.model(inputs_embeds=embeddings, attention_mask=mask).logits[:, self.num_prompts:].swapaxes(1, 2)
        return logits


class LitModelPromptTuning(L.LightningModule):
    def __init__(self, model, lr=1e-4):
        super().__init__()
        self.model = model
        self.lr = lr

        self.save_hyperparameters(ignore=['model'])

    def training_step(self, batch, batch_idx):
        X, y = batch
        logits = self.model(X)
        loss = F.cross_entropy(logits, target=y, ignore_index=50257)  # 50257 = id of the added [PAD] token

        self.log('Training loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch

        logits = self.model(X)
        loss = F.cross_entropy(logits, target=y, ignore_index=50257)

        self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer


from lightning.pytorch.loggers import WandbLogger

if __name__ == '__main__':
    torch.set_float32_matmul_precision('medium')
    dl_train, dl_val, dl_test = utils.import_data(bs=20, fraction=0.1)
    gpt_model = PromptTuningModel()
    # gpt_model = torch.compile(gpt_model)
    model = LitModelPromptTuning(model=gpt_model)
    print('Training')

    logger = WandbLogger(project='Anlp-3')
    trainer = L.Trainer(
        accelerator='gpu',
        # strategy='auto',
        # strategy=pl.strategies.DDPStrategy(find_unused_parameters=True),
        devices=[3],
        default_root_dir='./logs/',  # TensorBoard can be used to visualize these logs
        num_nodes=1,
        num_sanity_val_steps=1,  # runs a validation step before starting training
        precision='16-mixed',  # we use half precision to reduce memory usage
        max_epochs=10,
        check_val_every_n_epoch=1,  # run validation every epoch
        log_every_n_steps=20,
        logger=logger,
        # detect_anomaly=True,
    )

    trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
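This main.py freezes GPT-2 and trains only a 6-vector soft prompt: the prompt embeddings are concatenated onto the token embeddings, the attention mask is extended to cover them, and num_prompts positions are cropped from the logits before the loss. A self-contained toy sketch of those mechanics, with made-up sizes so it runs without downloading GPT-2 (all names and dimensions here are illustrative, not from the repository):

import torch
import torch.nn as nn

vocab, d_model, num_prompts, pad_id = 100, 16, 6, 99
wte = nn.Embedding(vocab, d_model)                        # stands in for model.transformer.wte
prompt = nn.Parameter(torch.zeros(num_prompts, d_model))  # the only trainable tensor

X = torch.tensor([[5, 7, pad_id], [3, pad_id, pad_id]])   # (batch, seq) with padding
emb = torch.cat([wte(X), prompt[None].repeat(X.shape[0], 1, 1)], dim=1)
mask = torch.cat([torch.ones(X.shape[0], num_prompts, dtype=torch.long),
                  (X != pad_id).long()], dim=1)
print(emb.shape, mask.shape)  # torch.Size([2, 9, 16]) torch.Size([2, 9])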
wandb/run-20241028_085806-owcrwbil/files/config.yaml
ADDED
@@ -0,0 +1,69 @@
_wandb:
  value:
    cli_version: 0.18.1
    code_path: code/main.py
    m:
      - "1": epoch
        "5": 2
        "6":
          - 1
          - 3
        "7": []
      - "1": trainer/global_step
        "6":
          - 3
        "7": []
      - "1": Training loss_step
        "5": 2
        "6":
          - 1
          - 3
        "7": []
      - "1": Validation loss_step
        "5": 2
        "6":
          - 1
          - 3
        "7": []
      - "1": Validation loss_epoch
        "5": 2
        "6":
          - 1
          - 3
        "7": []
      - "1": Training loss_epoch
        "5": 2
        "6":
          - 1
          - 3
        "7": []
    python_version: 3.10.13
    t:
      "1":
        - 1
        - 11
        - 49
        - 55
        - 71
        - 106
      "2":
        - 1
        - 11
        - 49
        - 55
        - 71
        - 106
      "3":
        - 7
        - 23
        - 55
        - 66
      "4": 3.10.13
      "5": 0.18.1
      "6": 4.44.2
      "8":
        - 5
      "12": 0.18.1
      "13": linux-x86_64
lr:
  value: 0.0001
wandb/run-20241028_085806-owcrwbil/files/output.log
ADDED
@@ -0,0 +1,16 @@
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name  | Type              | Params | Mode
----------------------------------------------------
0 | model | PromptTuningModel | 124 M  | train
----------------------------------------------------
4.6 K     Trainable params
124 M     Non-trainable params
124 M     Total params
497.922   Total estimated model params size (MB)
1         Modules in train mode
164       Modules in eval mode
Epoch 1:  56%|███████████▋        | 59/105 [00:24<00:19, 2.36it/s, v_num=wbil]

Detected KeyboardInterrupt, attempting graceful shutdown ...
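The summary above shows only 4.6 K of the 124 M parameters training, which matches the soft prompt exactly: 6 positions x 768 dimensions = 4,608 values. A short sketch of how such a split can be computed for any module (a generic helper written for illustration, not code from the repository):

import torch.nn as nn

def count_params(model: nn.Module):
    # Split parameter counts the way the Lightning model summary does.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    return trainable, frozen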
wandb/run-20241028_085806-owcrwbil/files/requirements.txt
ADDED
@@ -0,0 +1,178 @@
jiter==0.5.0
anyio==4.6.0
interegular==0.3.3
jaxlib==0.4.34
jsonschema==4.23.0
typing_extensions==4.12.2
httpcore==1.0.5
prometheus_client==0.21.0
openai==1.51.0
multidict==6.1.0
six==1.16.0
nvidia-nccl-cu12==2.20.5
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
watchfiles==0.24.0
tqdm==4.66.5
yarl==1.11.1
cffi==1.17.1
vllm==0.6.1.post2
bleach==6.1.0
kaggle==1.6.17
pydantic_core==2.23.4
lightning-utilities==0.11.7
sentry-sdk==2.14.0
torch==2.4.0
aiohappyeyeballs==2.4.0
diffusers==0.15.0
GitPython==3.1.43
attrs==24.2.0
importlib_metadata==8.5.0
transformers==4.44.2
pillow==10.4.0
sounddevice==0.5.1
gguf==0.9.1
python-dotenv==1.0.1
async-timeout==4.0.3
dspy-ai==2.5.3
numpy==1.26.4
nvidia-nvjitlink-cu12==12.6.68
uvicorn==0.30.6
kiwisolver==1.4.7
partial-json-parser==0.2.1.1.post4
pyparsing==3.1.4
lightning==2.4.0
structlog==24.4.0
nvidia-curand-cu12==10.3.2.106
setuptools==65.5.0
webencodings==0.5.1
nvidia-nvtx-cu12==12.1.105
sniffio==1.3.1
MarkupSafe==2.1.5
vllm-flash-attn==2.6.1
urllib3==2.2.3
requests==2.32.3
pycountry==24.6.1
ujson==5.10.0
matplotlib==3.9.2
pydantic==2.9.2
torchvision==0.19.0
numba==0.60.0
optuna==4.0.0
opt_einsum==3.4.0
joblib==1.4.2
msgpack==1.1.0
smmap==5.0.1
filelock==3.16.1
opencv-contrib-python==4.10.0.84
faiss-gpu==1.7.2
prometheus-fastapi-instrumentator==7.0.0
rpds-py==0.20.0
psutil==6.0.0
colorlog==6.8.2
nvidia-cufft-cu12==11.0.2.54
SQLAlchemy==2.0.35
llvmlite==0.43.0
packaging==24.1
exceptiongroup==1.2.2
dill==0.3.8
ml_dtypes==0.5.0
pyairports==2.1.1
scikit-learn==1.5.2
prettytable==3.11.0
protobuf==4.25.5
charset-normalizer==3.3.2
torchmetrics==1.4.2
text-unidecode==1.3
httpx==0.27.2
sympy==1.13.3
msgspec==0.18.6
wandb==0.18.1
backoff==2.2.1
sentencepiece==0.2.0
aiohttp==3.10.5
distro==1.9.0
lark==1.2.2
pyarrow==17.0.0
Mako==1.3.5
regex==2024.9.11
safetensors==0.4.5
aiosignal==1.3.1
jsonschema-specifications==2023.12.1
cloudpickle==3.0.0
einops==0.8.0
ray==2.36.1
fire==0.7.0
pyzmq==26.2.0
pycparser==2.22
platformdirs==4.3.6
click==8.1.7
fastapi==0.115.0
ftfy==6.3.0
torchtext==0.18.0
lm-format-enforcer==0.10.6
fsspec==2024.6.1
tzdata==2024.2
starlette==0.38.6
cycler==0.12.1
py-cpuinfo==9.0.0
h11==0.14.0
huggingface-hub==0.25.1
nvidia-cusparse-cu12==12.1.0.106
nvidia-ml-py==12.560.30
certifi==2024.8.30
httptools==0.6.1
jax==0.4.34
PyYAML==6.0.2
xxhash==3.5.0
idna==3.10
xformers==0.0.27.post2
mistral_common==1.4.3
fonttools==4.54.0
pip==23.0.1
accelerate==0.34.2
mediapipe==0.10.15
pytorch-lightning==2.4.0
ollama==0.3.3
Jinja2==3.1.4
multiprocess==0.70.16
opencv-python==4.10.0.84
termcolor==2.5.0
python-dateutil==2.9.0.post0
contourpy==1.3.0
websockets==13.1
frozenlist==1.4.1
pandas==2.2.3
networkx==3.3
diskcache==5.6.3
nvidia-cusolver-cu12==11.4.5.107
flatbuffers==24.3.25
mpmath==1.3.0
setproctitle==1.3.3
tokenizers==0.19.1
scipy==1.14.1
outlines==0.0.46
annotated-types==0.7.0
docker-pycreds==0.4.0
magicattr==0.1.6
wcwidth==0.2.13
pytorch-metric-learning==2.6.0
datasets==3.0.0
gitdb==4.0.11
lora-diffusion==0.1.7
referencing==0.35.1
python-slugify==8.0.4
zipp==3.20.2
triton==3.0.0
absl-py==2.1.0
threadpoolctl==3.5.0
uvloop==0.20.0
tiktoken==0.7.0
pytz==2024.2
nest-asyncio==1.6.0
nvidia-cublas-cu12==12.1.3.1
litellm==1.48.12
nvidia-cuda-nvrtc-cu12==12.1.105
greenlet==3.1.1
alembic==1.13.3
wandb/run-20241028_085806-owcrwbil/files/wandb-metadata.json
ADDED
@@ -0,0 +1,60 @@
{
  "os": "Linux-5.15.161-ql-generic-13.0-14-x86_64-with-glibc2.35",
  "python": "3.10.13",
  "startedAt": "2024-10-28T04:58:06.613389Z",
  "program": "/home/siddharth.tourani/Kyrylo/nlp/main.py",
  "codePath": "main.py",
  "email": "kyrylo.shyvam@students.iiit.ac.in",
  "root": ".",
  "host": "gpu-08",
  "username": "siddharth.tourani",
  "executable": "/home/siddharth.tourani/Minimal/bin/python3",
  "codePathLocal": "main.py",
  "cpu_count": 128,
  "cpu_count_logical": 256,
  "gpu": "[NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB]",
  "gpu_count": 4,
  "disk": {
    "/": {
      "total": "1073741824",
      "used": "21049344"
    }
  },
  "memory": {
    "total": "270256893952"
  },
  "cpu": {
    "count": 128,
    "countLogical": 256
  },
  "gpu_nvidia": [
    {
      "name": "NVIDIA A100-SXM4-40GB",
      "memoryTotal": "42949672960",
      "cudaCores": 6912,
      "architecture": "Ampere"
    },
    {
      "name": "NVIDIA A100-SXM4-40GB",
      "memoryTotal": "42949672960",
      "cudaCores": 6912,
      "architecture": "Ampere"
    },
    {
      "name": "NVIDIA A100-SXM4-40GB",
      "memoryTotal": "42949672960",
      "cudaCores": 6912,
      "architecture": "Ampere"
    },
    {
      "name": "NVIDIA A100-SXM4-40GB",
      "memoryTotal": "42949672960",
      "cudaCores": 6912,
      "architecture": "Ampere"
    }
  ],
  "slurm": {
    "job_id": "68153"
  },
  "cudaVersion": "12.4"
}
wandb/run-20241028_085806-owcrwbil/files/wandb-summary.json
ADDED
@@ -0,0 +1 @@
{"_runtime":77.016630791,"Validation loss_step":107.9970703125,"Validation loss_epoch":108.69440460205078,"_timestamp":1.7300915636298473e+09,"_wandb":{"runtime":79},"epoch":1,"Training loss_step":107.52567291259766,"trainer/global_step":159,"_step":39,"Training loss_epoch":108.31751251220703}
wandb/run-20241028_085806-owcrwbil/logs/debug-core.log
ADDED
@@ -0,0 +1,13 @@
{"time":"2024-10-28T08:58:05.852186282+04:00","level":"INFO","msg":"started logging, with flags","port-filename":"/tmp/tmpbt47w5hb/port-1241278.txt","pid":1241278,"debug":false,"disable-analytics":false}
{"time":"2024-10-28T08:58:05.852211372+04:00","level":"INFO","msg":"FeatureState","shutdownOnParentExitEnabled":false}
{"time":"2024-10-28T08:58:05.853103688+04:00","level":"INFO","msg":"Will exit if parent process dies.","ppid":1241278}
{"time":"2024-10-28T08:58:05.853096439+04:00","level":"INFO","msg":"server is running","addr":{"IP":"127.0.0.1","Port":39955,"Zone":""}}
{"time":"2024-10-28T08:58:06.04806319+04:00","level":"INFO","msg":"created new connection","id":"127.0.0.1:41868"}
{"time":"2024-10-28T08:58:06.614533728+04:00","level":"INFO","msg":"connection init received","streamId":"owcrwbil","id":"127.0.0.1:41868"}
{"time":"2024-10-28T08:58:06.615450814+04:00","level":"ERROR","msg":"error creating symlink","error":"symlink /home/siddharth.tourani/.cache/wandb/logs/core-debug-20241028_085805.log wandb/run-20241028_085806-owcrwbil/logs/debug-core.log: file exists"}
{"time":"2024-10-28T08:58:06.620019897+04:00","level":"INFO","msg":"connection init completed","streamId":"owcrwbil","id":"127.0.0.1:41868"}
{"time":"2024-10-28T08:59:25.917611765+04:00","level":"INFO","msg":"connection: teardown","id":"127.0.0.1:41868"}
{"time":"2024-10-28T08:59:25.917909894+04:00","level":"INFO","msg":"server is shutting down"}
{"time":"2024-10-28T08:59:25.917970554+04:00","level":"INFO","msg":"closed connection","id":"127.0.0.1:41868"}
{"time":"2024-10-28T08:59:32.101339343+04:00","level":"INFO","msg":"connection closed","id":"127.0.0.1:41868"}
{"time":"2024-10-28T08:59:32.101350683+04:00","level":"INFO","msg":"server is closed"}
wandb/run-20241028_085806-owcrwbil/logs/debug-internal.log
ADDED
@@ -0,0 +1,22 @@
{"time":"2024-10-28T08:58:06.615297525+04:00","level":"INFO","msg":"using version","core version":"0.18.1"}
{"time":"2024-10-28T08:58:06.615309735+04:00","level":"INFO","msg":"created symlink","path":"wandb/run-20241028_085806-owcrwbil/logs/debug-core.log"}
{"time":"2024-10-28T08:58:06.615703263+04:00","level":"INFO","msg":"using version","core version":"0.18.1"}
{"time":"2024-10-28T08:58:06.615710223+04:00","level":"INFO","msg":"created symlink","path":"wandb/run-20241028_085806-owcrwbil/logs/debug-core.log"}
{"time":"2024-10-28T08:58:06.619998427+04:00","level":"INFO","msg":"created new stream","id":"owcrwbil"}
{"time":"2024-10-28T08:58:06.620015787+04:00","level":"INFO","msg":"stream: started","id":"owcrwbil"}
{"time":"2024-10-28T08:58:06.620034997+04:00","level":"INFO","msg":"sender: started","stream_id":{"value":"owcrwbil"}}
{"time":"2024-10-28T08:58:06.620043877+04:00","level":"INFO","msg":"handler: started","stream_id":{"value":"owcrwbil"}}
{"time":"2024-10-28T08:58:06.620038217+04:00","level":"INFO","msg":"writer: Do: started","stream_id":{"value":"owcrwbil"}}
{"time":"2024-10-28T08:58:07.277024244+04:00","level":"INFO","msg":"wandb-core","!BADKEY":null}
{"time":"2024-10-28T08:58:07.279017506+04:00","level":"INFO","msg":"Starting system monitor"}
{"time":"2024-10-28T08:58:07.28068419+04:00","level":"ERROR","msg":"git repo not found","error":"repository does not exist"}
{"time":"2024-10-28T08:59:25.917904554+04:00","level":"INFO","msg":"stream: closing","id":"owcrwbil"}
{"time":"2024-10-28T08:59:25.917952654+04:00","level":"INFO","msg":"Stopping system monitor"}
{"time":"2024-10-28T08:59:25.919306298+04:00","level":"INFO","msg":"Stopped system monitor"}
{"time":"2024-10-28T08:59:26.173567721+04:00","level":"WARN","msg":"No job ingredients found, not creating job artifact"}
{"time":"2024-10-28T08:59:26.173578511+04:00","level":"WARN","msg":"No source type found, not creating job artifact"}
{"time":"2024-10-28T08:59:26.173581711+04:00","level":"INFO","msg":"sender: sendDefer: no job artifact to save"}
{"time":"2024-10-28T08:59:32.100682276+04:00","level":"INFO","msg":"handler: closed","stream_id":{"value":"owcrwbil"}}
{"time":"2024-10-28T08:59:32.100731535+04:00","level":"INFO","msg":"writer: Close: closed","stream_id":{"value":"owcrwbil"}}
{"time":"2024-10-28T08:59:32.100765025+04:00","level":"INFO","msg":"sender: closed","stream_id":{"value":"owcrwbil"}}
{"time":"2024-10-28T08:59:32.101266313+04:00","level":"INFO","msg":"stream: closed","id":"owcrwbil"}
wandb/run-20241028_085806-owcrwbil/logs/debug.log
ADDED
@@ -0,0 +1,27 @@
2024-10-28 08:58:06,610 INFO MainThread:1241278 [wandb_setup.py:_flush():77] Current SDK version is 0.18.1
2024-10-28 08:58:06,610 INFO MainThread:1241278 [wandb_setup.py:_flush():77] Configure stats pid to 1241278
2024-10-28 08:58:06,610 INFO MainThread:1241278 [wandb_setup.py:_flush():77] Loading settings from /home/siddharth.tourani/.config/wandb/settings
2024-10-28 08:58:06,610 INFO MainThread:1241278 [wandb_setup.py:_flush():77] Loading settings from /home/siddharth.tourani/Kyrylo/nlp/wandb/settings
2024-10-28 08:58:06,610 INFO MainThread:1241278 [wandb_setup.py:_flush():77] Loading settings from environment variables: {}
2024-10-28 08:58:06,610 INFO MainThread:1241278 [wandb_setup.py:_flush():77] Applying setup settings: {'mode': None, '_disable_service': None}
2024-10-28 08:58:06,610 INFO MainThread:1241278 [wandb_setup.py:_flush():77] Inferring run settings from compute environment: {'program_relpath': 'main.py', 'program_abspath': '/home/siddharth.tourani/Kyrylo/nlp/main.py', 'program': '/home/siddharth.tourani/Kyrylo/nlp/main.py'}
2024-10-28 08:58:06,610 INFO MainThread:1241278 [wandb_setup.py:_flush():77] Applying login settings: {}
2024-10-28 08:58:06,610 INFO MainThread:1241278 [wandb_init.py:_log_setup():532] Logging user logs to ./wandb/run-20241028_085806-owcrwbil/logs/debug.log
2024-10-28 08:58:06,610 INFO MainThread:1241278 [wandb_init.py:_log_setup():533] Logging internal logs to ./wandb/run-20241028_085806-owcrwbil/logs/debug-internal.log
2024-10-28 08:58:06,611 INFO MainThread:1241278 [wandb_init.py:init():616] calling init triggers
2024-10-28 08:58:06,611 INFO MainThread:1241278 [wandb_init.py:init():623] wandb.init called with sweep_config: {}
config: {}
2024-10-28 08:58:06,611 INFO MainThread:1241278 [wandb_init.py:init():666] starting backend
2024-10-28 08:58:06,611 INFO MainThread:1241278 [wandb_init.py:init():670] setting up manager
2024-10-28 08:58:06,611 INFO MainThread:1241278 [backend.py:_multiprocessing_setup():105] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
2024-10-28 08:58:06,613 INFO MainThread:1241278 [wandb_init.py:init():678] backend started and connected
2024-10-28 08:58:06,615 INFO MainThread:1241278 [wandb_init.py:init():773] updated telemetry
2024-10-28 08:58:06,616 INFO MainThread:1241278 [wandb_init.py:init():806] communicating run to backend with 90.0 second timeout
2024-10-28 08:58:07,272 INFO MainThread:1241278 [wandb_init.py:init():857] starting run threads in backend
2024-10-28 08:58:07,450 INFO MainThread:1241278 [wandb_run.py:_console_start():2459] atexit reg
2024-10-28 08:58:07,450 INFO MainThread:1241278 [wandb_run.py:_redirect():2307] redirect: wrap_raw
2024-10-28 08:58:07,450 INFO MainThread:1241278 [wandb_run.py:_redirect():2372] Wrapping output streams.
2024-10-28 08:58:07,450 INFO MainThread:1241278 [wandb_run.py:_redirect():2397] Redirects installed.
2024-10-28 08:58:07,452 INFO MainThread:1241278 [wandb_init.py:init():900] run started, returning control to user process
2024-10-28 08:58:07,549 INFO MainThread:1241278 [wandb_run.py:_config_callback():1388] config_cb None None {'lr': 0.0001}
2024-10-28 08:59:25,918 WARNING MsgRouterThr:1241278 [router.py:message_loop():77] message_loop has been closed
wandb/run-20241028_085806-owcrwbil/run-owcrwbil.wandb
ADDED
Binary file (211 kB)
wandb/run-20241028_090044-f9fzz8iy/files/code/main.py
ADDED
@@ -0,0 +1,124 @@
from math import inf
# from utils import *
import torch
import torch.nn as nn
import numpy as np
import torch.utils
import torch.utils.data
# from utils import MyDataset, custom_collate
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
import wandb
import torch.nn.functional as F

import einops
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast

np.random.seed(123)
torch.manual_seed(123)
torch.cuda.random.manual_seed(123)

import lightning as L
import utils


class PromptTuningModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.requires_grad_(False)  # freeze the backbone; only the soft prompt trains
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]

        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer), pad_to_multiple_of=128)

        # Initialise the soft prompt from the embedding of 'summarise',
        # tiled until it covers num_prompts positions.
        tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
        token_embedding = self.model.transformer.wte(tmp[0])
        self.token_embedding = token_embedding
        for _ in range(num_prompts // 3 - 1):
            self.token_embedding = torch.cat([self.token_embedding, token_embedding])

        data = torch.zeros(num_prompts, 768) + self.token_embedding[:]
        self.learnable_prompt = nn.Parameter(data, requires_grad=True)

    # @torch.compile
    def forward(self, X):
        # Use a local alias instead of reassigning the Parameter: assigning the
        # plain tensor returned by .to() back to a registered Parameter attribute
        # raises a TypeError whenever the move actually makes a copy.
        prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)  # b s d

        embeddings = torch.cat([embeddings, prompt[None, :, :].repeat(X.shape[0], 1, 1)], dim=1)
        mask = torch.cat([torch.ones([X.shape[0], self.num_prompts], dtype=torch.long).to(X), torch.where(X != self.pad.to(X.device), 1, 0)], dim=1)
        logits = self.model(inputs_embeds=embeddings, attention_mask=mask).logits[:, self.num_prompts:].swapaxes(1, 2)
        return logits


class LitModelPromptTuning(L.LightningModule):
    def __init__(self, model, lr=1e-4):
        super().__init__()
        self.model = model
        self.lr = lr

        self.save_hyperparameters(ignore=['model'])

    def training_step(self, batch, batch_idx):
        X, y = batch
        logits = self.model(X)
        loss = F.cross_entropy(logits, target=y, ignore_index=50257)  # 50257 = id of the added [PAD] token

        self.log('Training loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch

        logits = self.model(X)
        loss = F.cross_entropy(logits, target=y, ignore_index=50257)

        self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer


from lightning.pytorch.loggers import WandbLogger

if __name__ == '__main__':
    torch.set_float32_matmul_precision('medium')
    dl_train, dl_val, dl_test = utils.import_data(bs=25, fraction=0.1)
    gpt_model = PromptTuningModel()
    gpt_model = torch.compile(gpt_model)
    model = LitModelPromptTuning(model=gpt_model)
    print('Training')

    logger = WandbLogger(project='Anlp-3')
    trainer = L.Trainer(
        accelerator='gpu',
        # strategy='auto',
        # strategy=pl.strategies.DDPStrategy(find_unused_parameters=True),
        devices=[3],
        default_root_dir='./logs/',  # TensorBoard can be used to visualize these logs
        num_nodes=1,
        num_sanity_val_steps=1,  # runs a validation step before starting training
        precision='16-mixed',  # we use half precision to reduce memory usage
        max_epochs=10,
        check_val_every_n_epoch=1,  # run validation every epoch
        log_every_n_steps=20,
        logger=logger,
        # detect_anomaly=True,
    )

    trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
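Relative to the previous run's main.py, this version raises the batch size to 25 and enables torch.compile. A minimal sketch of what that wrapper does (a toy example, not code from the repository; torch.compile requires PyTorch >= 2.0, and torch==2.4.0 is pinned in the requirements above):

import torch
import torch.nn as nn

net = nn.Linear(8, 2)
compiled = torch.compile(net)       # returns an optimized callable with the same interface
out = compiled(torch.randn(4, 8))   # drop-in replacement for net(...)
print(out.shape)                    # torch.Size([4, 2])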
wandb/run-20241028_090044-f9fzz8iy/files/config.yaml
ADDED
@@ -0,0 +1,51 @@
_wandb:
  value:
    cli_version: 0.18.1
    code_path: code/main.py
    m:
      - "1": trainer/global_step
        "6":
          - 3
        "7": []
      - "1": Training loss_step
        "5": 1
        "6":
          - 1
          - 3
        "7": []
      - "1": epoch
        "5": 1
        "6":
          - 1
          - 3
        "7": []
    python_version: 3.10.13
    t:
      "1":
        - 1
        - 11
        - 49
        - 55
        - 71
        - 106
      "2":
        - 1
        - 11
        - 49
        - 55
        - 71
        - 106
      "3":
        - 7
        - 23
        - 55
        - 66
      "4": 3.10.13
      "5": 0.18.1
      "6": 4.44.2
      "8":
        - 5
      "12": 0.18.1
      "13": linux-x86_64
lr:
  value: 0.0001
wandb/run-20241028_090044-f9fzz8iy/files/output.log
ADDED
@@ -0,0 +1,15 @@
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name  | Type              | Params | Mode
----------------------------------------------------
0 | model | PromptTuningModel | 124 M  | train
----------------------------------------------------
4.6 K     Trainable params
124 M     Non-trainable params
124 M     Total params
497.922   Total estimated model params size (MB)
1         Modules in train mode
164       Modules in eval mode
Epoch 0:  67%|█████████████▋      | 56/84 [00:27<00:13, 2.07it/s, v_num=z8iy]

Detected KeyboardInterrupt, attempting graceful shutdown ...
wandb/run-20241028_090044-f9fzz8iy/files/requirements.txt
ADDED
@@ -0,0 +1,178 @@
jiter==0.5.0
anyio==4.6.0
interegular==0.3.3
jaxlib==0.4.34
jsonschema==4.23.0
typing_extensions==4.12.2
httpcore==1.0.5
prometheus_client==0.21.0
openai==1.51.0
multidict==6.1.0
six==1.16.0
nvidia-nccl-cu12==2.20.5
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
watchfiles==0.24.0
tqdm==4.66.5
yarl==1.11.1
cffi==1.17.1
vllm==0.6.1.post2
bleach==6.1.0
kaggle==1.6.17
pydantic_core==2.23.4
lightning-utilities==0.11.7
sentry-sdk==2.14.0
torch==2.4.0
aiohappyeyeballs==2.4.0
diffusers==0.15.0
GitPython==3.1.43
attrs==24.2.0
importlib_metadata==8.5.0
transformers==4.44.2
pillow==10.4.0
sounddevice==0.5.1
gguf==0.9.1
python-dotenv==1.0.1
async-timeout==4.0.3
dspy-ai==2.5.3
numpy==1.26.4
nvidia-nvjitlink-cu12==12.6.68
uvicorn==0.30.6
kiwisolver==1.4.7
partial-json-parser==0.2.1.1.post4
pyparsing==3.1.4
lightning==2.4.0
structlog==24.4.0
nvidia-curand-cu12==10.3.2.106
setuptools==65.5.0
webencodings==0.5.1
nvidia-nvtx-cu12==12.1.105
sniffio==1.3.1
MarkupSafe==2.1.5
vllm-flash-attn==2.6.1
urllib3==2.2.3
requests==2.32.3
pycountry==24.6.1
ujson==5.10.0
matplotlib==3.9.2
pydantic==2.9.2
torchvision==0.19.0
numba==0.60.0
optuna==4.0.0
opt_einsum==3.4.0
joblib==1.4.2
msgpack==1.1.0
smmap==5.0.1
filelock==3.16.1
opencv-contrib-python==4.10.0.84
faiss-gpu==1.7.2
prometheus-fastapi-instrumentator==7.0.0
rpds-py==0.20.0
psutil==6.0.0
colorlog==6.8.2
nvidia-cufft-cu12==11.0.2.54
SQLAlchemy==2.0.35
llvmlite==0.43.0
packaging==24.1
exceptiongroup==1.2.2
dill==0.3.8
ml_dtypes==0.5.0
pyairports==2.1.1
scikit-learn==1.5.2
prettytable==3.11.0
protobuf==4.25.5
charset-normalizer==3.3.2
torchmetrics==1.4.2
text-unidecode==1.3
httpx==0.27.2
sympy==1.13.3
msgspec==0.18.6
wandb==0.18.1
backoff==2.2.1
sentencepiece==0.2.0
aiohttp==3.10.5
distro==1.9.0
lark==1.2.2
pyarrow==17.0.0
Mako==1.3.5
regex==2024.9.11
safetensors==0.4.5
aiosignal==1.3.1
jsonschema-specifications==2023.12.1
cloudpickle==3.0.0
einops==0.8.0
ray==2.36.1
fire==0.7.0
pyzmq==26.2.0
pycparser==2.22
platformdirs==4.3.6
click==8.1.7
fastapi==0.115.0
ftfy==6.3.0
torchtext==0.18.0
lm-format-enforcer==0.10.6
fsspec==2024.6.1
tzdata==2024.2
starlette==0.38.6
cycler==0.12.1
py-cpuinfo==9.0.0
h11==0.14.0
huggingface-hub==0.25.1
nvidia-cusparse-cu12==12.1.0.106
nvidia-ml-py==12.560.30
certifi==2024.8.30
httptools==0.6.1
jax==0.4.34
PyYAML==6.0.2
xxhash==3.5.0
idna==3.10
xformers==0.0.27.post2
mistral_common==1.4.3
fonttools==4.54.0
pip==23.0.1
accelerate==0.34.2
mediapipe==0.10.15
pytorch-lightning==2.4.0
ollama==0.3.3
Jinja2==3.1.4
multiprocess==0.70.16
opencv-python==4.10.0.84
termcolor==2.5.0
python-dateutil==2.9.0.post0
contourpy==1.3.0
websockets==13.1
frozenlist==1.4.1
pandas==2.2.3
networkx==3.3
diskcache==5.6.3
nvidia-cusolver-cu12==11.4.5.107
flatbuffers==24.3.25
mpmath==1.3.0
setproctitle==1.3.3
tokenizers==0.19.1
scipy==1.14.1
outlines==0.0.46
annotated-types==0.7.0
docker-pycreds==0.4.0
magicattr==0.1.6
wcwidth==0.2.13
pytorch-metric-learning==2.6.0
datasets==3.0.0
gitdb==4.0.11
lora-diffusion==0.1.7
referencing==0.35.1
python-slugify==8.0.4
zipp==3.20.2
triton==3.0.0
absl-py==2.1.0
threadpoolctl==3.5.0
uvloop==0.20.0
tiktoken==0.7.0
pytz==2024.2
nest-asyncio==1.6.0
nvidia-cublas-cu12==12.1.3.1
litellm==1.48.12
nvidia-cuda-nvrtc-cu12==12.1.105
greenlet==3.1.1
alembic==1.13.3
wandb/run-20241028_090044-f9fzz8iy/files/wandb-metadata.json
ADDED
@@ -0,0 +1,60 @@
{
  "os": "Linux-5.15.161-ql-generic-13.0-14-x86_64-with-glibc2.35",
  "python": "3.10.13",
  "startedAt": "2024-10-28T05:00:44.986886Z",
  "program": "/home/siddharth.tourani/Kyrylo/nlp/main.py",
  "codePath": "main.py",
  "email": "kyrylo.shyvam@students.iiit.ac.in",
  "root": ".",
  "host": "gpu-08",
  "username": "siddharth.tourani",
  "executable": "/home/siddharth.tourani/Minimal/bin/python3",
  "codePathLocal": "main.py",
  "cpu_count": 128,
  "cpu_count_logical": 256,
  "gpu": "[NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB]",
  "gpu_count": 4,
  "disk": {
    "/": {
      "total": "1073741824",
      "used": "21049344"
    }
  },
  "memory": {
    "total": "270256893952"
  },
  "cpu": {
    "count": 128,
    "countLogical": 256
  },
  "gpu_nvidia": [
    {
      "name": "NVIDIA A100-SXM4-40GB",
      "memoryTotal": "42949672960",
      "cudaCores": 6912,
      "architecture": "Ampere"
    },
    {
      "name": "NVIDIA A100-SXM4-40GB",
      "memoryTotal": "42949672960",
      "cudaCores": 6912,
      "architecture": "Ampere"
    },
    {
      "name": "NVIDIA A100-SXM4-40GB",
      "memoryTotal": "42949672960",
      "cudaCores": 6912,
      "architecture": "Ampere"
    },
    {
      "name": "NVIDIA A100-SXM4-40GB",
      "memoryTotal": "42949672960",
      "cudaCores": 6912,
      "architecture": "Ampere"
    }
  ],
  "slurm": {
    "job_id": "68153"
  },
  "cudaVersion": "12.4"
}
wandb/run-20241028_090044-f9fzz8iy/files/wandb-summary.json
ADDED
@@ -0,0 +1 @@
{"_wandb":{"runtime":30},"_timestamp":1.7300916668905182e+09,"_runtime":21.903801157,"_step":1,"Training loss_step":104.03557586669922,"epoch":0,"trainer/global_step":39}
wandb/run-20241028_090044-f9fzz8iy/logs/debug-core.log
ADDED
@@ -0,0 +1,13 @@
{"time":"2024-10-28T09:00:44.271373064+04:00","level":"INFO","msg":"started logging, with flags","port-filename":"/tmp/tmpr43jmo4e/port-1243414.txt","pid":1243414,"debug":false,"disable-analytics":false}
{"time":"2024-10-28T09:00:44.271400724+04:00","level":"INFO","msg":"FeatureState","shutdownOnParentExitEnabled":false}
{"time":"2024-10-28T09:00:44.272333799+04:00","level":"INFO","msg":"Will exit if parent process dies.","ppid":1243414}
{"time":"2024-10-28T09:00:44.27232289+04:00","level":"INFO","msg":"server is running","addr":{"IP":"127.0.0.1","Port":36087,"Zone":""}}
{"time":"2024-10-28T09:00:44.467138216+04:00","level":"INFO","msg":"created new connection","id":"127.0.0.1:52740"}
{"time":"2024-10-28T09:00:44.988394668+04:00","level":"INFO","msg":"connection init received","streamId":"f9fzz8iy","id":"127.0.0.1:52740"}
{"time":"2024-10-28T09:00:44.989450074+04:00","level":"ERROR","msg":"error creating symlink","error":"symlink /home/siddharth.tourani/.cache/wandb/logs/core-debug-20241028_090044.log wandb/run-20241028_090044-f9fzz8iy/logs/debug-core.log: file exists"}
{"time":"2024-10-28T09:00:44.994677699+04:00","level":"INFO","msg":"connection init completed","streamId":"f9fzz8iy","id":"127.0.0.1:52740"}
{"time":"2024-10-28T09:01:15.297081483+04:00","level":"INFO","msg":"connection: teardown","id":"127.0.0.1:52740"}
{"time":"2024-10-28T09:01:15.297356992+04:00","level":"INFO","msg":"server is shutting down"}
{"time":"2024-10-28T09:01:15.297394512+04:00","level":"INFO","msg":"closed connection","id":"127.0.0.1:52740"}
{"time":"2024-10-28T09:01:16.667813027+04:00","level":"INFO","msg":"connection closed","id":"127.0.0.1:52740"}
{"time":"2024-10-28T09:01:16.667830647+04:00","level":"INFO","msg":"server is closed"}
wandb/run-20241028_090044-f9fzz8iy/logs/debug-internal.log
ADDED
@@ -0,0 +1,22 @@
{"time":"2024-10-28T09:00:44.989289054+04:00","level":"INFO","msg":"using version","core version":"0.18.1"}
{"time":"2024-10-28T09:00:44.989302624+04:00","level":"INFO","msg":"created symlink","path":"wandb/run-20241028_090044-f9fzz8iy/logs/debug-core.log"}
{"time":"2024-10-28T09:00:44.989755452+04:00","level":"INFO","msg":"using version","core version":"0.18.1"}
{"time":"2024-10-28T09:00:44.989762312+04:00","level":"INFO","msg":"created symlink","path":"wandb/run-20241028_090044-f9fzz8iy/logs/debug-core.log"}
{"time":"2024-10-28T09:00:44.994658889+04:00","level":"INFO","msg":"created new stream","id":"f9fzz8iy"}
{"time":"2024-10-28T09:00:44.994674479+04:00","level":"INFO","msg":"stream: started","id":"f9fzz8iy"}
{"time":"2024-10-28T09:00:44.994693389+04:00","level":"INFO","msg":"sender: started","stream_id":{"value":"f9fzz8iy"}}
{"time":"2024-10-28T09:00:44.994713418+04:00","level":"INFO","msg":"handler: started","stream_id":{"value":"f9fzz8iy"}}
{"time":"2024-10-28T09:00:44.994695909+04:00","level":"INFO","msg":"writer: Do: started","stream_id":{"value":"f9fzz8iy"}}
{"time":"2024-10-28T09:00:45.712012987+04:00","level":"INFO","msg":"wandb-core","!BADKEY":null}
{"time":"2024-10-28T09:00:45.713870838+04:00","level":"INFO","msg":"Starting system monitor"}
{"time":"2024-10-28T09:00:45.716597735+04:00","level":"ERROR","msg":"git repo not found","error":"repository does not exist"}
{"time":"2024-10-28T09:01:15.297352912+04:00","level":"INFO","msg":"stream: closing","id":"f9fzz8iy"}
{"time":"2024-10-28T09:01:15.297391541+04:00","level":"INFO","msg":"Stopping system monitor"}
{"time":"2024-10-28T09:01:15.2976609+04:00","level":"INFO","msg":"Stopped system monitor"}
{"time":"2024-10-28T09:01:15.570782955+04:00","level":"WARN","msg":"No job ingredients found, not creating job artifact"}
{"time":"2024-10-28T09:01:15.570791505+04:00","level":"WARN","msg":"No source type found, not creating job artifact"}
{"time":"2024-10-28T09:01:15.570794545+04:00","level":"INFO","msg":"sender: sendDefer: no job artifact to save"}
{"time":"2024-10-28T09:01:16.667000401+04:00","level":"INFO","msg":"handler: closed","stream_id":{"value":"f9fzz8iy"}}
{"time":"2024-10-28T09:01:16.667029211+04:00","level":"INFO","msg":"writer: Close: closed","stream_id":{"value":"f9fzz8iy"}}
{"time":"2024-10-28T09:01:16.667090251+04:00","level":"INFO","msg":"sender: closed","stream_id":{"value":"f9fzz8iy"}}
{"time":"2024-10-28T09:01:16.667475389+04:00","level":"INFO","msg":"stream: closed","id":"f9fzz8iy"}
wandb/run-20241028_090044-f9fzz8iy/logs/debug.log
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_setup.py:_flush():77] Current SDK version is 0.18.1
+2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_setup.py:_flush():77] Configure stats pid to 1243414
+2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_setup.py:_flush():77] Loading settings from /home/siddharth.tourani/.config/wandb/settings
+2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_setup.py:_flush():77] Loading settings from /home/siddharth.tourani/Kyrylo/nlp/wandb/settings
+2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_setup.py:_flush():77] Loading settings from environment variables: {}
+2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_setup.py:_flush():77] Applying setup settings: {'mode': None, '_disable_service': None}
+2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_setup.py:_flush():77] Inferring run settings from compute environment: {'program_relpath': 'main.py', 'program_abspath': '/home/siddharth.tourani/Kyrylo/nlp/main.py', 'program': '/home/siddharth.tourani/Kyrylo/nlp/main.py'}
+2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_setup.py:_flush():77] Applying login settings: {}
+2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_init.py:_log_setup():532] Logging user logs to ./wandb/run-20241028_090044-f9fzz8iy/logs/debug.log
+2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_init.py:_log_setup():533] Logging internal logs to ./wandb/run-20241028_090044-f9fzz8iy/logs/debug-internal.log
+2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_init.py:init():616] calling init triggers
+2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_init.py:init():623] wandb.init called with sweep_config: {}
+config: {}
+2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_init.py:init():666] starting backend
+2024-10-28 09:00:44,984 INFO MainThread:1243414 [wandb_init.py:init():670] setting up manager
+2024-10-28 09:00:44,985 INFO MainThread:1243414 [backend.py:_multiprocessing_setup():105] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
+2024-10-28 09:00:44,986 INFO MainThread:1243414 [wandb_init.py:init():678] backend started and connected
+2024-10-28 09:00:44,989 INFO MainThread:1243414 [wandb_init.py:init():773] updated telemetry
+2024-10-28 09:00:44,989 INFO MainThread:1243414 [wandb_init.py:init():806] communicating run to backend with 90.0 second timeout
+2024-10-28 09:00:45,707 INFO MainThread:1243414 [wandb_init.py:init():857] starting run threads in backend
+2024-10-28 09:00:45,884 INFO MainThread:1243414 [wandb_run.py:_console_start():2459] atexit reg
+2024-10-28 09:00:45,884 INFO MainThread:1243414 [wandb_run.py:_redirect():2307] redirect: wrap_raw
+2024-10-28 09:00:45,884 INFO MainThread:1243414 [wandb_run.py:_redirect():2372] Wrapping output streams.
+2024-10-28 09:00:45,884 INFO MainThread:1243414 [wandb_run.py:_redirect():2397] Redirects installed.
+2024-10-28 09:00:45,885 INFO MainThread:1243414 [wandb_init.py:init():900] run started, returning control to user process
+2024-10-28 09:00:45,983 INFO MainThread:1243414 [wandb_run.py:_config_callback():1388] config_cb None None {'lr': 0.0001}
+2024-10-28 09:01:15,297 WARNING MsgRouterThr:1243414 [router.py:message_loop():77] message_loop has been closed
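The `config_cb None None {'lr': 0.0001}` entry above records the run config being pushed to the backend; in this repo it originates from `save_hyperparameters()` in the Lightning module, forwarded through `WandbLogger`. As a minimal sketch (using only the public `wandb` API; the project name is taken from this repo's training script, everything else is illustrative), the lifecycle recorded in this debug.log can be reproduced directly:

    import wandb

    # wandb.init starts the backend and writes debug.log / debug-internal.log
    # under ./wandb/run-<timestamp>-<id>/logs/, as in the entries above.
    run = wandb.init(project='Anlp-3')

    # Updating the config fires the config callback logged as
    # "config_cb None None {'lr': 0.0001}".
    run.config.update({'lr': 1e-4})

    # Finishing the run produces the "stream: closing" / "Stopped system
    # monitor" lines seen in debug-internal.log.
    run.finish()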
wandb/run-20241028_090044-f9fzz8iy/run-f9fzz8iy.wandb
ADDED
Binary file (60.2 kB)
wandb/run-20241028_090149-4jbvn26d/files/code/main.py
ADDED
@@ -0,0 +1,124 @@
+from math import inf
+# from utils import *
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.utils
+import torch.utils.data
+# from utils import MyDataset, custom_collate
+from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
+import wandb
+import torch.nn.functional as F
+
+import einops
+from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast
+
+# Fix RNG seeds for reproducibility.
+np.random.seed(123)
+torch.manual_seed(123)
+torch.cuda.random.manual_seed(123)
+
+import lightning as L
+import utils
+
+
+class PromptTuningModel(nn.Module):
+    """Frozen GPT-2 backbone with a small set of learnable soft-prompt embeddings."""
+
+    def __init__(self, num_prompts=6):
+        super().__init__()
+        self.num_prompts = num_prompts
+
+        # Freeze the backbone: only the soft prompt below receives gradients.
+        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
+        self.model.requires_grad_(False)
+        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
+        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+
+        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
+        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
+
+        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer), pad_to_multiple_of=128)
+
+        # Initialise the soft prompt from the embeddings of 'summarise' (3 BPE
+        # tokens), tiled to fill num_prompts rows; num_prompts is assumed to be
+        # a multiple of 3.
+        tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
+        token_embedding = self.model.transformer.wte(tmp[0])
+        self.token_embedding = token_embedding
+        for _ in range(num_prompts // 3 - 1):
+            self.token_embedding = torch.cat([self.token_embedding, token_embedding])
+
+        data = torch.zeros(num_prompts, 768) + self.token_embedding[:]
+        self.learnable_prompt = nn.Parameter(data, requires_grad=True)
+
+    # @torch.compile
+    def forward(self, X):
+        embeddings = self.model.transformer.wte(X)  # (batch, seq, dim)
+
+        # Prepend the soft prompt to every sequence in the batch; prepending keeps
+        # the embeddings aligned with the attention mask and the logits slice below.
+        prompt = self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1)
+        embeddings = torch.cat([prompt, embeddings], dim=1)
+        # Attend to all prompt positions; mask out [PAD] positions in the input.
+        mask = torch.cat([torch.ones([X.shape[0], self.num_prompts], dtype=torch.long).to(X),
+                          torch.where(X != self.pad.to(X.device), 1, 0)], dim=1)
+        # Drop the prompt positions, then move the vocab axis to dim 1 for F.cross_entropy.
+        logits = self.model(inputs_embeds=embeddings, attention_mask=mask).logits[:, self.num_prompts:].swapaxes(1, 2)
+        return logits
+
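A hypothetical CPU smoke test for the class above (not part of the original file; batch size and sequence length are arbitrary) makes the tensor shapes concrete:

    # With num_prompts=6 and a (2, 10) batch of token ids, embeddings become
    # (2, 16, 768) after the prompt is prepended, and the returned logits are
    # (2, 50304, 10): the prompt positions are dropped and the vocab axis
    # (50257 tokens padded to a multiple of 128) is moved to dim 1 for
    # F.cross_entropy.
    m = PromptTuningModel(num_prompts=6)
    X = torch.randint(0, 50256, (2, 10))
    print(m(X).shape)  # torch.Size([2, 50304, 10])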
+class LitModelPromptTuning(L.LightningModule):
+    def __init__(self, model, lr=1e-4):
+        super().__init__()
+        self.model = model
+        self.lr = lr
+
+        # Log hyperparameters (here just lr) to the logger; the model itself is skipped.
+        self.save_hyperparameters(ignore=['model'])
+
+    def training_step(self, batch, batch_idx):
+        X, y = batch
+        logits = self.model(X)
+        # 50257 is the [PAD] id added in PromptTuningModel; padded targets are ignored.
+        loss = F.cross_entropy(logits, target=y, ignore_index=50257)
+
+        self.log('Training loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        X, y = batch
+        logits = self.model(X)
+        loss = F.cross_entropy(logits, target=y, ignore_index=50257)
+
+        self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
+        return loss
+
+    def configure_optimizers(self):
+        # Only the soft prompt has requires_grad=True, so AdamW updates just that tensor.
+        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
+        return optimizer
+
+
+from lightning.pytorch.loggers import WandbLogger
+
+if __name__ == '__main__':
+    torch.set_float32_matmul_precision('medium')
+    dl_train, dl_val, dl_test = utils.import_data(bs=25, fraction=0.1)
+    gpt_model = PromptTuningModel()
+    gpt_model = torch.compile(gpt_model)
+    model = LitModelPromptTuning(model=gpt_model)
+    print('Training')
+
+    logger = WandbLogger(project='Anlp-3')
+    trainer = L.Trainer(
+        accelerator='gpu',
+        # strategy=pl.strategies.DDPStrategy(find_unused_parameters=True),
+        devices=[3],
+        default_root_dir='./logs/',  # TensorBoard can be pointed here to visualise runs
+        num_nodes=1,
+        num_sanity_val_steps=1,  # run one validation step before starting training
+        precision='16-mixed',  # half precision to reduce memory usage
+        max_epochs=10,
+        check_val_every_n_epoch=1,  # run validation every epoch
+        log_every_n_steps=20,
+        logger=logger,
+        # detect_anomaly=True,
+    )
+
+    trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
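`utils.import_data` is defined in this repo's utils.py, which is not shown in this part of the diff. A hypothetical stand-in (names and shapes are illustrative, not the repo's implementation) shows the contract the training loop assumes: three dataloaders yielding `(input_ids, target_ids)` LongTensor pairs, with padded target positions carrying the [PAD] id 50257 so they are skipped by the loss:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def import_data_stub(bs=25, fraction=0.1):
        # Illustrative only: random token ids in place of the CNN/DailyMail data.
        X = torch.randint(0, 50257, (100, 32))  # input ids
        y = torch.randint(0, 50257, (100, 32))  # target ids; padded with 50257 in practice
        dl = DataLoader(TensorDataset(X, y), batch_size=bs, shuffle=True)
        return dl, dl, dl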