pt-sk committed on
Commit
ee1779a
·
verified ·
1 Parent(s): 3f0e944

Create train.py

Browse files
Files changed (1) hide show
  1. code/train.py +28 -0
code/train.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Next-token-prediction training loop for ``mamba_model``.

Each iteration wraps one text sample in the tokenizer's BOS/EOS tokens,
runs it through the model, computes a shifted cross-entropy loss and takes
one optimizer step. The current loss is shown in the tqdm progress bar.

NOTE(review): ``dataloader``, ``tokenizer``, ``mamba_model`` and
``optimizer`` are used below but are not defined in this file; they must
already exist in the executing namespace -- TODO confirm where they are
created.
"""
import torch
from tqdm import tqdm

# Use the GPU when one is present; otherwise fall back to the CPU instead of
# failing on an unconditional .to("cuda").
_input_device = "cuda" if torch.cuda.is_available() else "cpu"

# The criterion is loop-invariant, so build it once instead of on every step.
loss_fct = torch.nn.CrossEntropyLoss()

iterator = tqdm(dataloader, desc="Training", postfix={"train_loss": 0.0})

for item in iterator:
    # Only the first element of each batch is consumed.
    # NOTE(review): presumably the dataloader yields batches holding a single
    # string -- confirm against the dataset definition.
    text = tokenizer.bos_token + " " + item[0] + " " + tokenizer.eos_token
    encoded_inp = tokenizer(text, return_tensors="pt").input_ids.to(_input_device)
    logits = mamba_model(encoded_inp)

    # Shift by one so the logits at position t are scored against token t + 1:
    # drop the last logit step and the first label.
    labels = encoded_inp.to(logits.device)
    shift_logits = logits[:, :-1, :].contiguous()
    labels = labels[:, 1:].contiguous()
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Only the scalar loss is needed on the host. Copying logits/labels to
    # NumPy every step would force a large device-to-host transfer for values
    # that are never read.
    iterator.set_postfix({"train_loss": loss.item()}, refresh=False)