Create train.py
code/train.py +35 -0
code/train.py
ADDED
@@ -0,0 +1,35 @@
+import torch
+from tqdm import tqdm
+
+# NOTE: `dataloader`, `tokenizer`, `mamba_model`, and `optimizer` are assumed
+# to be defined before this script runs; the model call is assumed to return
+# logits directly.
+iterator = tqdm(dataloader, desc="Training", postfix={"train_loss": 0.0})
+
+for item in iterator:
+    # Wrap each raw text sample with BOS/EOS markers before tokenizing.
+    item = tokenizer.bos_token + " " + item[0] + " " + tokenizer.eos_token
+    encoded_inp = tokenizer(item, return_tensors="pt").input_ids.to("cuda")
+    logits = mamba_model(encoded_inp)
+
+    # Standard causal-LM shift: the logits at position i are scored against
+    # token i+1, so drop the last logit and the first label.
+    labels = encoded_inp.to(logits.device)
+    shift_logits = logits[:, :-1, :].contiguous()
+    labels = labels[:, 1:].contiguous()
+    loss_fct = torch.nn.CrossEntropyLoss()
+    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
+
+    optimizer.zero_grad(set_to_none=True)
+    loss.backward()
+    optimizer.step()
+
+    # Move tensors from GPU to CPU so they stop holding GPU memory and
+    # autograd references; only `loss` is reused, for the progress bar.
+    loss = loss.detach().cpu().numpy()
+    logits = logits.detach().cpu().numpy()
+    labels = labels.detach().cpu().numpy()
+    encoded_inp = encoded_inp.detach().cpu().numpy()
+    shift_logits = shift_logits.detach().cpu().numpy()
+
+    iterator.set_postfix({"train_loss": loss.item()}, refresh=False)
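
The file uses `dataloader`, `tokenizer`, `mamba_model`, and `optimizer` without defining them, so it is not runnable on its own. A minimal setup sketch under assumed choices (the gpt-neox tokenizer and the state-spaces/mamba-130m checkpoint are illustrative, and `my_text_dataset` is a hypothetical dataset of raw strings; none of this is part of the commit) could look like:

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Assumed checkpoint/tokenizer pairing; swap in whatever you trained with.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
base_model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda")

# train.py expects the model call to return logits directly, while
# MambaLMHeadModel returns a namedtuple, so unwrap .logits here.
mamba_model = lambda input_ids: base_model(input_ids).logits

optimizer = torch.optim.AdamW(base_model.parameters(), lr=5e-5)

# my_text_dataset is hypothetical: any Dataset yielding strings works,
# since the loop reads the raw text from item[0].
dataloader = DataLoader(my_text_dataset, batch_size=1, shuffle=True)

With batch_size=1, the default collate turns each sample into a one-element list, so `item[0]` in the loop picks out the single raw string per step; larger batches would need padding and masking, which this loop does not handle.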
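
The shift-and-flatten loss in the loop is the standard next-token cross-entropy. A tiny self-contained check of the shapes (toy numbers, unrelated to the commit):

import torch

# batch=1, seq_len=4, vocab=10: logits at position i predict token i+1.
logits = torch.randn(1, 4, 10)
input_ids = torch.tensor([[2, 5, 7, 1]])

shift_logits = logits[:, :-1, :]   # predictions for positions 0..2
labels = input_ids[:, 1:]          # targets are tokens 1..3
loss = torch.nn.functional.cross_entropy(
    shift_logits.reshape(-1, 10),  # (3, 10)
    labels.reshape(-1),            # (3,)
)
print(float(loss))  # mean cross-entropy over the 3 predicted positions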