# mamba/code/train.py
import torch
from tqdm import tqdm
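
# --- Assumed setup (a minimal sketch; none of this is in the original file).
# The loop below expects `tokenizer`, `mamba_model`, `optimizer`, and
# `dataloader` to already be defined, with `mamba_model(input_ids)` returning
# raw logits of shape (batch, seq_len, vocab_size). The checkpoint name,
# learning rate, and dataset below are illustrative placeholders only.
#
# from transformers import AutoTokenizer
# from torch.utils.data import DataLoader
#
# tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
# mamba_model = ...  # a Mamba language model whose forward returns logits, on "cuda"
# optimizer = torch.optim.AdamW(mamba_model.parameters(), lr=5e-5)
# dataloader = DataLoader(text_dataset, batch_size=1)  # yields (text,) tuples
# mamba_model.train()  # enable dropout/normalization training behavior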
iterator = tqdm(dataloader, desc="Training", postfix={"train_loss": 0.0})
for item in iterator:
    # Wrap the raw text sample with BOS/EOS markers before tokenizing.
    text = tokenizer.bos_token + " " + item[0] + " " + tokenizer.eos_token
    encoded_inp = tokenizer(text, return_tensors="pt").input_ids.to("cuda")

    # Forward pass: logits has shape (batch, seq_len, vocab_size).
    logits = mamba_model(encoded_inp)
    labels = encoded_inp.to(logits.device)

    # Standard causal-LM shift: the logit at position t predicts token t+1.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()

    loss_fct = torch.nn.CrossEntropyLoss()
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Log the scalar loss; .item() already moves it to CPU, and the
    # intermediate tensors are recreated each iteration, so no explicit
    # .detach().cpu().numpy() copies are needed.
    iterator.set_postfix({"train_loss": loss.item()}, refresh=False)
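
# --- Optional: persist the trained weights (an assumption, not in the
# original file; the output path is a placeholder).
# torch.save(mamba_model.state_dict(), "mamba_trained.pt")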