Crystalcareai commited on
Commit
fba2fba
1 Parent(s): 351d904

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +9 -58
modeling_quiet.py CHANGED
@@ -18,9 +18,7 @@
18
  # See the License for the specific language governing permissions and
19
  # limitations under the License.
20
  """ PyTorch Quiet model."""
21
- import inspect
22
  import math
23
- import pdb
24
  import warnings
25
  from collections import defaultdict
26
  from typing import List, Optional, Tuple, Union
@@ -31,8 +29,7 @@ import torch.utils.checkpoint
31
  from torch import nn
32
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
33
  from transformers.generation.utils import GenerationMixin
34
- from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria
35
- from transformers import TextStreamer, AutoTokenizer
36
  import transformers
37
 
38
  from transformers.activations import ACT2FN
@@ -43,8 +40,6 @@ from transformers.modeling_utils import PreTrainedModel
43
  from transformers.utils import (
44
  add_start_docstrings,
45
  add_start_docstrings_to_model_forward,
46
- is_flash_attn_2_available,
47
- is_flash_attn_greater_or_equal_2_10,
48
  logging,
49
  replace_return_docstrings,
50
  )
@@ -240,7 +235,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
240
  num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
241
  """
242
 
243
- # pdb.set_trace()
244
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
245
  if n_rep == 1:
246
  return hidden_states
@@ -332,7 +326,7 @@ class QuietAttention(nn.Module):
332
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
333
 
334
  if past_key_value is not None:
335
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
336
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
337
 
338
  # repeat k/v heads if n_kv_heads < n_heads
@@ -377,8 +371,7 @@ class QuietAttention(nn.Module):
377
  )
378
 
379
  attn_weights = attn_weights + attention_mask
380
-
381
- # upcast attention to fp32
382
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
383
  attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
384
  attn_output = torch.matmul(attn_weights, value_states)
@@ -851,16 +844,12 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
851
  self.model = QuietModel(config)
852
  self.vocab_size = config.vocab_size
853
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
854
- # self.router_aux_loss_coef = config.router_aux_loss_coef
855
- # self.num_experts = config.num_experts
856
- # self.num_experts_per_tok = config.num_experts_per_tok
857
  self.max_thoughts = config.max_thoughts
858
  self.merged_lm_and_talk_heads = config.merged_lm_and_talk_heads
859
  self.use_concat_talk_head = config.use_concat_talk_head
860
  self.use_shallow_talk = config.use_shallow_talk
861
  self.use_complex_talk_head = config.use_complex_talk_head
862
  self.use_weighted_talk_head = config.use_weighted_talk_head
863
- # the weighted head will output a single value, so it can't be passed to the lm head
864
  assert not (self.use_weighted_talk_head and self.use_shallow_talk)
865
 
866
  self.n_ahead = 1
@@ -931,7 +920,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
931
  self.thinking_threshold = 0.5
932
  self.thinking_usefulness_loss_weight = 1e-2
933
 
934
- # Not used in the paper:
935
  self.use_thought_prefix = False
936
  self.use_reparam_for_thought_embeddings = False
937
  self.use_upper_triangular = False
@@ -939,7 +927,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
939
  self.comparison_mode = False
940
  self.gumbel_detach = False
941
 
942
- # For visualization
943
  self.eval_mode = False
944
 
945
  num_talk = 1
@@ -968,7 +955,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
968
  # Add dropout regularization
969
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
970
 
971
- # Initialize weights and apply final processing
972
  self.post_init()
973
 
974
  def get_input_embeddings(self):
@@ -1219,20 +1205,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1219
  n_passes_to_restore = self.n_passes
1220
  self.n_ahead_talk = 1
1221
  self.n_passes = 1
1222
-
1223
- # aux_loss = None
1224
- # output_router_logits = output_router_logits if output_router_logits is not None else self.config.output_router_logits
1225
- # if output_router_logits:
1226
- # router_logits = outputs.router_logits if return_dict else outputs[-1]
1227
- # if router_logits is not None:
1228
- # aux_loss = load_balancing_loss_func(
1229
- # router_logits,
1230
- # self.num_experts,
1231
- # self.num_experts_per_tok,
1232
- # attention_mask,
1233
- # )
1234
- # if labels is not None:
1235
- # loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
1236
  if input_ids.dim() == 1:
1237
  input_ids = input_ids.unsqueeze(0)
1238
  attention_mask = attention_mask.unsqueeze(0) if attention_mask is not None else None
@@ -1300,7 +1272,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1300
  self.start_token_id = self.tokenizer.bos_token_id
1301
  self.tokenizer_has_start_thought_token = False
1302
  elif self.use_start_thought_token:
1303
- # base_start_id = self.tokenizer.convert_tokens_to_ids(self.initial_start_token)
1304
  base_start_id = self.tokenizer.encode(self.initial_start_token, add_special_tokens=False)[0]
1305
  if self.initialize_thought_embedding_to_normal:
1306
  self.start_embedding.data = torch.zeros_like(self.start_embedding.data)
@@ -1313,7 +1284,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1313
  self.end_token_id = self.tokenizer.eos_token_id
1314
  self.tokenizer_has_end_thought_token = False
1315
  elif self.use_end_thought_token:
1316
- # base_end_id = self.tokenizer.convert_tokens_to_ids(self.initial_end_token)
1317
  base_end_id = self.tokenizer.encode(self.initial_end_token, add_special_tokens=False)[0]
1318
  if self.initialize_thought_embedding_to_normal:
1319
  self.end_embedding.data = torch.zeros_like(self.end_embedding.data)
@@ -1332,7 +1302,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1332
  else:
1333
  # convert to identity transform
1334
  def lambda_transform(cur_head):
1335
- # pdb.set_trace()
1336
  if cur_head.weight.data.shape[0] != cur_head.weight.data.shape[1]:
1337
  return torch.cat([
1338
  torch.eye(
@@ -1360,28 +1329,23 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1360
  self.talk_head[-1].weight.data = lambda_transform(self.talk_head[0])
1361
 
1362
  loss = None
1363
- prev_rm_tokens = None
1364
  cur_rm_tokens = None
1365
- prev_rm_logits = None
1366
  prev_sample_probs = None
1367
  did_skip_sampling = None
1368
  skip_sampling = None
1369
  sample_probs = None
1370
  hidden_states = None
1371
  logits = None
1372
- talk_kl_penalty = None
1373
  rm_logits = None
1374
  residual_logits = None
1375
  probabilities_2d = None
1376
  prev_probabilities_2d = None
1377
  policy_reward = None
1378
- logits_to_output = None
1379
  batch_size, seq_len = input_ids.shape
1380
  base_input_ids = input_ids.clone()
1381
  loss_list = []
1382
  dqn_loss_list = []
1383
  sampled_token_history = []
1384
- sample_probs_history = []
1385
  action_loglikelihoods_list = []
1386
 
1387
  temperature = self.temperature
@@ -1397,7 +1361,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1397
  if self.train_only_thinking_embedding:
1398
  base_embeddings = base_embeddings.detach()
1399
 
1400
- # # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1401
  fwd_iters = 1 if self.original_mode else self.n_ahead + self.n_ahead_talk - 1
1402
  for ahead_idx in range(fwd_iters):
1403
  past_key_values_length = 0
@@ -1442,15 +1406,12 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1442
  base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
1443
  base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
1444
  attention_mask = base_attention_mask
1445
- # breakpoint()
1446
  elif attention_mask.dim() == 2:
1447
  if seq_len + past_key_values_length != attention_mask.shape[-1]:
1448
- # breakpoint()
1449
  attention_mask = torch.cat(
1450
  [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
1451
  dim=-1
1452
  )
1453
- # # if the attention mask
1454
  attention_mask = _prepare_4d_causal_attention_mask(
1455
  attention_mask,
1456
  (batch_size, seq_len),
@@ -1460,7 +1421,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1460
  )
1461
 
1462
  outputs = self.model(
1463
- # input_ids=input_ids,
1464
  attention_mask=attention_mask,
1465
  position_ids=position_ids,
1466
  past_key_values=past_key_values,
@@ -1468,14 +1428,13 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1468
  use_cache=use_cache,
1469
  output_attentions=output_attentions,
1470
  output_hidden_states=output_hidden_states,
1471
- # output_router_logits=output_router_logits,
1472
  return_dict=return_dict,
1473
  )
1474
 
1475
  prev_hidden_states = hidden_states
1476
  hidden_states = outputs[0]
1477
- prev_rm_logits = rm_logits # for policy gradient
1478
- prev_rm_tokens = cur_rm_tokens # for policy gradient
1479
 
1480
  if ahead_idx == 0:
1481
  hidden_states_lm = hidden_states
@@ -1521,7 +1480,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1521
  assert sum([self.cumulative_residual, self.clever_residual, self.skip_residual, self.no_residual]) == 1
1522
  if self.clever_residual:
1523
  if ahead_idx >= self.n_ahead - 1:
1524
- # get the logits shifted according to the current talk ahead
1525
  cur_base_logits = torch.cat([
1526
  base_logits[..., ahead_idx - self.n_ahead + 1:, :],
1527
  base_logits[..., :ahead_idx - self.n_ahead + 1, :]
@@ -1566,7 +1524,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1566
 
1567
  attempted = False
1568
  talk_loss_list = []
1569
- if self.original_mode or (self.n_ahead == 1) or (self.comparison_mode and ahead_idx == 0):# or (self.optimize_lm_head_only_at_start and ahead_idx == 0):
1570
  loss = None
1571
  attempted = True
1572
 
@@ -1597,7 +1555,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1597
 
1598
  if not attempted or self.comparison_mode:
1599
  rm_hidden_states = hidden_states
1600
- # print("Magnitude of RM hidden states before RM head", rm_hidden_states.norm())
1601
  rm_logits = apply_head(self.lm_head, rm_hidden_states, detach=self.optimize_lm_head_only_at_start)
1602
 
1603
  # don't allow it to predict the thinking token
@@ -1626,9 +1583,8 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1626
  probabilities_2d[:, override_token] = 1.0
1627
  skip_sampling = True
1628
  elif ahead_idx >= self.n_ahead - 1:
1629
- if labels is not None: # we're in the talk phase
1630
  cur_talk_n = ahead_idx - (self.n_ahead - 1) + 1
1631
- # print("Setting rm to labels", cur_talk_n, "during", ahead_idx)
1632
  shift_labels = labels[..., cur_talk_n:].contiguous().to(probabilities_2d.device)
1633
  padding = torch.full_like(
1634
  labels[..., :cur_talk_n],
@@ -1640,11 +1596,9 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1640
  [shift_labels, padding],
1641
  dim=-1
1642
  )
1643
-
1644
- # print((new_rm_tokens > self.vocab_size - 1).any().item())
1645
  new_rm_tokens = torch.clamp(new_rm_tokens, 0, self.vocab_size - 1)
1646
 
1647
- # Now safely convert rm tokens to one-hot
1648
  probabilities_2d = F.one_hot(new_rm_tokens, num_classes=self.vocab_size).reshape(-1, self.vocab_size).to(probabilities_2d.dtype)
1649
  else:
1650
  continue
@@ -1704,7 +1658,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1704
  new_attention = original_attention
1705
  else:
1706
  original_attention = original_attention == attention_mask.max()
1707
- # because eye isn't implemented for BF16, we need to handle the case
1708
  if not attention_mask.dtype == torch.bfloat16:
1709
  new_attention = torch.eye(
1710
  seq_len, dtype=attention_mask.dtype, device=attention_mask.device
@@ -1742,9 +1695,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1742
  # if shift_labels.min() == self.tokenizer.pad_token_id:
1743
  shift_labels = torch.where(shift_labels == self.tokenizer.pad_token_id, -100, shift_labels)
1744
  unreduced_loss = loss_fct(shift_logits, shift_labels)
1745
- # print("Loss:", unreduced_loss.item()) # Print the loss before checking for NaN values
1746
  if torch.any(unreduced_loss != unreduced_loss):
1747
- # pdb.set_trace()
1748
  raise ValueError("NaN loss")
1749
  unreduced_loss = unreduced_loss.reshape(logits.shape[0], -1)
1750
  loss_list.append(unreduced_loss)
 
18
  # See the License for the specific language governing permissions and
19
  # limitations under the License.
20
  """ PyTorch Quiet model."""
 
21
  import math
 
22
  import warnings
23
  from collections import defaultdict
24
  from typing import List, Optional, Tuple, Union
 
29
  from torch import nn
30
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
  from transformers.generation.utils import GenerationMixin
32
+ from transformers import AutoTokenizer
 
33
  import transformers
34
 
35
  from transformers.activations import ACT2FN
 
40
  from transformers.utils import (
41
  add_start_docstrings,
42
  add_start_docstrings_to_model_forward,
 
 
43
  logging,
44
  replace_return_docstrings,
45
  )
 
235
  num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
236
  """
237
 
 
238
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
239
  if n_rep == 1:
240
  return hidden_states
 
326
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
327
 
328
  if past_key_value is not None:
329
+ cache_kwargs = {"sin": sin, "cos": cos}
330
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
331
 
332
  # repeat k/v heads if n_kv_heads < n_heads
 
371
  )
372
 
373
  attn_weights = attn_weights + attention_mask
374
+
 
375
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
376
  attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
377
  attn_output = torch.matmul(attn_weights, value_states)
 
844
  self.model = QuietModel(config)
845
  self.vocab_size = config.vocab_size
846
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
 
 
847
  self.max_thoughts = config.max_thoughts
848
  self.merged_lm_and_talk_heads = config.merged_lm_and_talk_heads
849
  self.use_concat_talk_head = config.use_concat_talk_head
850
  self.use_shallow_talk = config.use_shallow_talk
851
  self.use_complex_talk_head = config.use_complex_talk_head
852
  self.use_weighted_talk_head = config.use_weighted_talk_head
 
853
  assert not (self.use_weighted_talk_head and self.use_shallow_talk)
854
 
855
  self.n_ahead = 1
 
920
  self.thinking_threshold = 0.5
921
  self.thinking_usefulness_loss_weight = 1e-2
922
 
 
923
  self.use_thought_prefix = False
924
  self.use_reparam_for_thought_embeddings = False
925
  self.use_upper_triangular = False
 
927
  self.comparison_mode = False
928
  self.gumbel_detach = False
929
 
 
930
  self.eval_mode = False
931
 
932
  num_talk = 1
 
955
  # Add dropout regularization
956
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
957
 
 
958
  self.post_init()
959
 
960
  def get_input_embeddings(self):
 
1205
  n_passes_to_restore = self.n_passes
1206
  self.n_ahead_talk = 1
1207
  self.n_passes = 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1208
  if input_ids.dim() == 1:
1209
  input_ids = input_ids.unsqueeze(0)
1210
  attention_mask = attention_mask.unsqueeze(0) if attention_mask is not None else None
 
1272
  self.start_token_id = self.tokenizer.bos_token_id
1273
  self.tokenizer_has_start_thought_token = False
1274
  elif self.use_start_thought_token:
 
1275
  base_start_id = self.tokenizer.encode(self.initial_start_token, add_special_tokens=False)[0]
1276
  if self.initialize_thought_embedding_to_normal:
1277
  self.start_embedding.data = torch.zeros_like(self.start_embedding.data)
 
1284
  self.end_token_id = self.tokenizer.eos_token_id
1285
  self.tokenizer_has_end_thought_token = False
1286
  elif self.use_end_thought_token:
 
1287
  base_end_id = self.tokenizer.encode(self.initial_end_token, add_special_tokens=False)[0]
1288
  if self.initialize_thought_embedding_to_normal:
1289
  self.end_embedding.data = torch.zeros_like(self.end_embedding.data)
 
1302
  else:
1303
  # convert to identity transform
1304
  def lambda_transform(cur_head):
 
1305
  if cur_head.weight.data.shape[0] != cur_head.weight.data.shape[1]:
1306
  return torch.cat([
1307
  torch.eye(
 
1329
  self.talk_head[-1].weight.data = lambda_transform(self.talk_head[0])
1330
 
1331
  loss = None
 
1332
  cur_rm_tokens = None
 
1333
  prev_sample_probs = None
1334
  did_skip_sampling = None
1335
  skip_sampling = None
1336
  sample_probs = None
1337
  hidden_states = None
1338
  logits = None
 
1339
  rm_logits = None
1340
  residual_logits = None
1341
  probabilities_2d = None
1342
  prev_probabilities_2d = None
1343
  policy_reward = None
 
1344
  batch_size, seq_len = input_ids.shape
1345
  base_input_ids = input_ids.clone()
1346
  loss_list = []
1347
  dqn_loss_list = []
1348
  sampled_token_history = []
 
1349
  action_loglikelihoods_list = []
1350
 
1351
  temperature = self.temperature
 
1361
  if self.train_only_thinking_embedding:
1362
  base_embeddings = base_embeddings.detach()
1363
 
1364
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1365
  fwd_iters = 1 if self.original_mode else self.n_ahead + self.n_ahead_talk - 1
1366
  for ahead_idx in range(fwd_iters):
1367
  past_key_values_length = 0
 
1406
  base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
1407
  base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
1408
  attention_mask = base_attention_mask
 
1409
  elif attention_mask.dim() == 2:
1410
  if seq_len + past_key_values_length != attention_mask.shape[-1]:
 
1411
  attention_mask = torch.cat(
1412
  [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
1413
  dim=-1
1414
  )
 
1415
  attention_mask = _prepare_4d_causal_attention_mask(
1416
  attention_mask,
1417
  (batch_size, seq_len),
 
1421
  )
1422
 
1423
  outputs = self.model(
 
1424
  attention_mask=attention_mask,
1425
  position_ids=position_ids,
1426
  past_key_values=past_key_values,
 
1428
  use_cache=use_cache,
1429
  output_attentions=output_attentions,
1430
  output_hidden_states=output_hidden_states,
 
1431
  return_dict=return_dict,
1432
  )
1433
 
1434
  prev_hidden_states = hidden_states
1435
  hidden_states = outputs[0]
1436
+ prev_rm_logits = rm_logits
1437
+ prev_rm_tokens = cur_rm_tokens
1438
 
1439
  if ahead_idx == 0:
1440
  hidden_states_lm = hidden_states
 
1480
  assert sum([self.cumulative_residual, self.clever_residual, self.skip_residual, self.no_residual]) == 1
1481
  if self.clever_residual:
1482
  if ahead_idx >= self.n_ahead - 1:
 
1483
  cur_base_logits = torch.cat([
1484
  base_logits[..., ahead_idx - self.n_ahead + 1:, :],
1485
  base_logits[..., :ahead_idx - self.n_ahead + 1, :]
 
1524
 
1525
  attempted = False
1526
  talk_loss_list = []
1527
+ if self.original_mode or (self.n_ahead == 1) or (self.comparison_mode and ahead_idx == 0):
1528
  loss = None
1529
  attempted = True
1530
 
 
1555
 
1556
  if not attempted or self.comparison_mode:
1557
  rm_hidden_states = hidden_states
 
1558
  rm_logits = apply_head(self.lm_head, rm_hidden_states, detach=self.optimize_lm_head_only_at_start)
1559
 
1560
  # don't allow it to predict the thinking token
 
1583
  probabilities_2d[:, override_token] = 1.0
1584
  skip_sampling = True
1585
  elif ahead_idx >= self.n_ahead - 1:
1586
+ if labels is not None:
1587
  cur_talk_n = ahead_idx - (self.n_ahead - 1) + 1
 
1588
  shift_labels = labels[..., cur_talk_n:].contiguous().to(probabilities_2d.device)
1589
  padding = torch.full_like(
1590
  labels[..., :cur_talk_n],
 
1596
  [shift_labels, padding],
1597
  dim=-1
1598
  )
1599
+
 
1600
  new_rm_tokens = torch.clamp(new_rm_tokens, 0, self.vocab_size - 1)
1601
 
 
1602
  probabilities_2d = F.one_hot(new_rm_tokens, num_classes=self.vocab_size).reshape(-1, self.vocab_size).to(probabilities_2d.dtype)
1603
  else:
1604
  continue
 
1658
  new_attention = original_attention
1659
  else:
1660
  original_attention = original_attention == attention_mask.max()
 
1661
  if not attention_mask.dtype == torch.bfloat16:
1662
  new_attention = torch.eye(
1663
  seq_len, dtype=attention_mask.dtype, device=attention_mask.device
 
1695
  # if shift_labels.min() == self.tokenizer.pad_token_id:
1696
  shift_labels = torch.where(shift_labels == self.tokenizer.pad_token_id, -100, shift_labels)
1697
  unreduced_loss = loss_fct(shift_logits, shift_labels)
 
1698
  if torch.any(unreduced_loss != unreduced_loss):
 
1699
  raise ValueError("NaN loss")
1700
  unreduced_loss = unreduced_loss.reshape(logits.shape[0], -1)
1701
  loss_list.append(unreduced_loss)