sade-adrien
commited on
Commit
•
15e8d2f
1
Parent(s):
6b9513a
Upload 2 files
Browse files
mapping_adapter_checkpoint_114000steps.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cb6e95ae3b9cd81f6d5bdcad0387c65aa804ec172336db5f89ba7ad7ffc1f8d2
|
3 |
+
size 125866547
|
representation_mapping.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoModel, AutoTokenizer, AutoConfig, AdamW, get_linear_schedule_with_warmup
|
2 |
+
from torch.utils.data import DataLoader
|
3 |
+
import transformers
|
4 |
+
from sklearn.model_selection import train_test_split
|
5 |
+
from datasets import load_dataset, DatasetDict
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch
|
8 |
+
import wandb
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
args_max_epoch = 1
|
12 |
+
args_batch_size = 64
|
13 |
+
args_learning_rate = 3e-5
|
14 |
+
args_num_warmup_steps = 100
|
15 |
+
args_gradient_accumulation_steps_default = 2
|
16 |
+
adapter_hidden_dim = 4096
|
17 |
+
|
18 |
+
device = 'cuda'
|
19 |
+
|
20 |
+
|
21 |
+
def main():
|
22 |
+
wandb.init(project="MappingAdapater_training_v6", name="training_run")
|
23 |
+
|
24 |
+
model = MappingStructure(checkpointE = "sentence-transformers/stsb-roberta-large",
|
25 |
+
checkpointD = "mistralai/Mistral-7B-Instruct-v0.1",
|
26 |
+
hidden_dim = adapter_hidden_dim,
|
27 |
+
torch_dtype = torch.float16,
|
28 |
+
flash_attn = True,
|
29 |
+
).to(device)
|
30 |
+
|
31 |
+
for n,p in model.named_parameters():
|
32 |
+
if 'mapping' not in n:
|
33 |
+
p.requires_grad = False
|
34 |
+
else:
|
35 |
+
p.requires_grad = True
|
36 |
+
|
37 |
+
dataset = load_dataset("sade-adrien/redpajama_v2_sample_10M")['train']
|
38 |
+
train_dataset, val_dataset = split_dataset(dataset, train_size=.989333)
|
39 |
+
datasets = DatasetDict({
|
40 |
+
'train': train_dataset,
|
41 |
+
'val': val_dataset
|
42 |
+
})
|
43 |
+
|
44 |
+
train_dataloader = DataLoader(datasets['train'], batch_size=args_batch_size, shuffle=True)
|
45 |
+
val_dataloader = DataLoader(datasets['val'], batch_size=args_batch_size, shuffle=False)
|
46 |
+
|
47 |
+
optimizer = AdamW(model.parameters(), lr=args_learning_rate)
|
48 |
+
scheduler = get_linear_schedule_with_warmup(optimizer, args_num_warmup_steps, args_max_epoch*len(train_dataloader))
|
49 |
+
|
50 |
+
global_step = 0
|
51 |
+
for epoch in range(args_max_epoch):
|
52 |
+
train_dataloader = DataLoader(datasets['train'], batch_size=args_batch_size, shuffle=True, worker_init_fn=lambda _: torch.manual_seed(epoch))
|
53 |
+
|
54 |
+
for batch in tqdm(train_dataloader):
|
55 |
+
input_prompt = batch['raw_content']
|
56 |
+
outputs = model(input_prompt=input_prompt, compute_loss=True)
|
57 |
+
loss = outputs['loss']
|
58 |
+
|
59 |
+
# Gradient accumulation
|
60 |
+
loss = loss / args_gradient_accumulation_steps_default
|
61 |
+
loss.backward()
|
62 |
+
|
63 |
+
if (global_step + 1) % args_gradient_accumulation_steps_default == 0:
|
64 |
+
optimizer.step()
|
65 |
+
optimizer.zero_grad()
|
66 |
+
scheduler.step()
|
67 |
+
|
68 |
+
|
69 |
+
if (global_step + 1) % 2000 == 0:
|
70 |
+
torch.save({
|
71 |
+
'epoch': epoch,
|
72 |
+
'mapping_state_dict': model.mapping.state_dict(),
|
73 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
74 |
+
'scheduler_state_dict': scheduler.state_dict(),
|
75 |
+
'global_step': global_step,
|
76 |
+
}, f'models/mapping_adapter_checkpoint_{global_step + 1}steps.pth')
|
77 |
+
|
78 |
+
global_step += 1
|
79 |
+
val_loss = None
|
80 |
+
if (global_step + 1) % 8000 == 0:
|
81 |
+
model.eval()
|
82 |
+
val_loss = 0.0
|
83 |
+
with torch.no_grad():
|
84 |
+
for val_batch in tqdm(val_dataloader):
|
85 |
+
val_inputs = val_batch['raw_content']
|
86 |
+
val_outputs = model(input_prompt=val_inputs, compute_loss=True)
|
87 |
+
val_loss += val_outputs['loss']
|
88 |
+
val_loss /= len(val_dataloader)
|
89 |
+
|
90 |
+
model.train()
|
91 |
+
|
92 |
+
wandb.log({
|
93 |
+
'step': global_step + 1,
|
94 |
+
'learning_rate': scheduler.get_last_lr()[0],
|
95 |
+
'train_loss': loss.item() * args_gradient_accumulation_steps_default,
|
96 |
+
'val_loss': val_loss.item() if val_loss else None
|
97 |
+
})
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
def split_dataset(dataset, train_size=.9):
|
103 |
+
index = int(len(dataset) * train_size)
|
104 |
+
return dataset.select(range(index)), dataset.select(range(index, len(dataset)))
|
105 |
+
|
106 |
+
class MappingAdapter(nn.Module):
|
107 |
+
def __init__(self, input_dim, output_dim, hidden_dim):
|
108 |
+
super(MappingAdapter, self).__init__()
|
109 |
+
self.layer1 = nn.Linear(input_dim, hidden_dim)
|
110 |
+
self.layer2 = nn.Linear(hidden_dim, output_dim)
|
111 |
+
self.activation = nn.LeakyReLU(.01)
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
x = self.layer1(x)
|
115 |
+
x = self.activation(x)
|
116 |
+
x = self.layer2(x)
|
117 |
+
return x
|
118 |
+
|
119 |
+
class MappingStructure(nn.Module):
|
120 |
+
def __init__(self, checkpointE, checkpointD, hidden_dim=2048, torch_dtype=torch.float32, flash_attn=False):
|
121 |
+
super(MappingStructure, self).__init__()
|
122 |
+
|
123 |
+
self.configE = AutoConfig.from_pretrained(checkpointE)
|
124 |
+
self.Encoder = AutoModel.from_pretrained(checkpointE,
|
125 |
+
low_cpu_mem_usage = True,
|
126 |
+
torch_dtype = torch_dtype,
|
127 |
+
config = self.configE
|
128 |
+
)
|
129 |
+
|
130 |
+
self.configD = AutoConfig.from_pretrained(checkpointD)
|
131 |
+
if flash_attn:
|
132 |
+
self.configD.update({'_flash_attn_2_enabled' : True})
|
133 |
+
self.Decoder = AutoModel.from_pretrained(checkpointD,
|
134 |
+
low_cpu_mem_usage = True,
|
135 |
+
torch_dtype = torch_dtype,
|
136 |
+
config = self.configD
|
137 |
+
)
|
138 |
+
|
139 |
+
self.mapping = MappingAdapter(self.configD.hidden_size, self.configE.hidden_size, hidden_dim=hidden_dim).to(torch_dtype)
|
140 |
+
|
141 |
+
self._init_tokenizers(checkpointE, checkpointD)
|
142 |
+
|
143 |
+
def _init_tokenizers(self, checkpointE, checkpointD):
|
144 |
+
self.tokenizerE = AutoTokenizer.from_pretrained(checkpointE, use_fast = False, revision = 'main', config = self.configE, padding_side='left')
|
145 |
+
self.tokenizerD = AutoTokenizer.from_pretrained(checkpointD, use_fast = False, revision = 'main', config = self.configD, padding_side='left')
|
146 |
+
self.tokenizerD.pad_token_id = self.tokenizerD.unk_token_id
|
147 |
+
|
148 |
+
def cosine_sim(self, u, v):
|
149 |
+
assert u.shape == v.shape, "u and v must have the same shape"
|
150 |
+
u_normalized = u / torch.norm(u, dim=1, keepdim=True)
|
151 |
+
v_normalized = v / torch.norm(v, dim=1, keepdim=True)
|
152 |
+
|
153 |
+
# Compute cosine similarity using dot product
|
154 |
+
return torch.sum(u_normalized * v_normalized, dim=1)
|
155 |
+
|
156 |
+
|
157 |
+
def mean_pooling(self, hidden_state, attention_mask):
|
158 |
+
token_embeddings = hidden_state
|
159 |
+
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
160 |
+
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
161 |
+
|
162 |
+
|
163 |
+
def build_batch(self, input_prompt):
|
164 |
+
size = torch.randint(1, self.configE.max_position_embeddings-2, (1,)).item()
|
165 |
+
targets = []
|
166 |
+
|
167 |
+
for prompt in input_prompt:
|
168 |
+
tokenized_input = self.tokenizerE(prompt)
|
169 |
+
tokenized_input = {'input_ids': tokenized_input['input_ids'][:size],
|
170 |
+
'attention_mask': tokenized_input['attention_mask'][:size],
|
171 |
+
|
172 |
+
}
|
173 |
+
targets.append(tokenized_input)
|
174 |
+
targets = self.tokenizerE.pad(targets, padding=True, return_tensors='pt')
|
175 |
+
|
176 |
+
return targets
|
177 |
+
|
178 |
+
|
179 |
+
def forward(self, input_prompt, compute_loss=False):
|
180 |
+
loss = None
|
181 |
+
|
182 |
+
# Slice prompt of needed to fit encoder max position embeddings (hard constraint)
|
183 |
+
if not compute_loss:
|
184 |
+
inputs = self.tokenizerD(input_prompt, return_tensors='pt', padding=True).to(device)
|
185 |
+
|
186 |
+
hidden_state_D = self.Decoder(**inputs).last_hidden_state
|
187 |
+
hidden_state_D_mapped = self.mapping(hidden_state_D)
|
188 |
+
|
189 |
+
else:
|
190 |
+
targets = self.build_batch(input_prompt).to(device)
|
191 |
+
|
192 |
+
input_prompt_sliced = self.tokenizerE.batch_decode(targets['input_ids'], skip_special_tokens=True)
|
193 |
+
inputs = self.tokenizerD(input_prompt_sliced, return_tensors='pt', padding=True).to(device)
|
194 |
+
|
195 |
+
hidden_state_D = self.Decoder(**inputs).last_hidden_state
|
196 |
+
hidden_state_D_mapped = self.mapping(hidden_state_D)
|
197 |
+
|
198 |
+
hidden_state_E = self.Encoder(**targets).last_hidden_state
|
199 |
+
|
200 |
+
proj_E = self.mean_pooling(hidden_state_E, targets['attention_mask'])
|
201 |
+
proj_D = self.mean_pooling(hidden_state_D_mapped, inputs['attention_mask'])
|
202 |
+
|
203 |
+
loss = 1 - torch.mean(self.cosine_sim(proj_E, proj_D))
|
204 |
+
|
205 |
+
del inputs
|
206 |
+
del targets
|
207 |
+
del input_prompt_sliced
|
208 |
+
del hidden_state_E
|
209 |
+
del proj_E
|
210 |
+
del proj_D
|
211 |
+
torch.cuda.empty_cache()
|
212 |
+
|
213 |
+
return {'loss': loss,
|
214 |
+
'last_hidden_state': hidden_state_D,
|
215 |
+
'last_hidden_state_mapped': hidden_state_D_mapped,
|
216 |
+
}
|
217 |
+
|
218 |
+
|
219 |
+
if __name__ == '__main__':
|
220 |
+
main()
|