import torch
from tqdm import tqdm

# Next-token prediction: position t's logits are scored against token t+1,
# so logits are shifted left and labels shifted right before the loss.
loss_fct = torch.nn.CrossEntropyLoss()

iterator = tqdm(dataloader, desc="Training", postfix={"train_loss": 0.0})
for item in iterator:
    # Wrap each raw text example with BOS/EOS markers before tokenizing.
    item = tokenizer.bos_token + " " + item[0] + " " + tokenizer.eos_token
    encoded_inp = tokenizer(item, return_tensors="pt").input_ids.to("cuda")

    logits = mamba_model(encoded_inp)
    labels = encoded_inp.to(logits.device)

    # Drop the last logit (nothing follows it) and the first label
    # (nothing predicts it), then flatten both for cross-entropy.
    shift_logits = logits[:, :-1, :].contiguous()
    labels = labels[:, 1:].contiguous()
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Detach and move tensors from GPU to CPU so their GPU memory
    # can be freed between steps.
    loss = loss.detach().cpu().numpy()
    logits = logits.detach().cpu().numpy()
    labels = labels.detach().cpu().numpy()
    encoded_inp = encoded_inp.detach().cpu().numpy()
    shift_logits = shift_logits.detach().cpu().numpy()

    iterator.set_postfix({"train_loss": loss.item()}, refresh=False)
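
The loop assumes that dataloader, tokenizer, mamba_model, and optimizer already exist, and that calling the model returns a raw logits tensor. Below is a minimal setup sketch under assumed names: the "state-spaces/mamba-130m-hf" checkpoint, AdamW with lr=5e-5, and a two-sentence corpus are all placeholders, not from the original; swap in your own model and data.

# Minimal setup sketch for the loop above. Everything concrete here
# (checkpoint name, learning rate, toy corpus) is an assumption.
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, MambaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
hf_model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf").to("cuda")
hf_model.train()

def mamba_model(input_ids):
    # The loop indexes the model output directly (logits[:, :-1, :]),
    # so unwrap the Hugging Face ModelOutput and return only the logits.
    return hf_model(input_ids).logits

optimizer = torch.optim.AdamW(hf_model.parameters(), lr=5e-5)

# Placeholder corpus; batch_size=1 makes each batch a one-element list
# of strings, which is why the loop reads item[0].
texts = ["example sentence one", "example sentence two"]
dataloader = DataLoader(texts, batch_size=1, shuffle=True)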