Yanisadel committed
Commit 02765e9
1 Parent(s): c7e0039

Update chatNT.py

Files changed (1)
  1. chatNT.py +4 -4
chatNT.py CHANGED
@@ -975,7 +975,7 @@ class TorchGptDecoder(nn.Module):
         self, embeddings: torch.Tensor, attention_mask: torch.Tensor = None
     ) -> torch.Tensor:
         if attention_mask is None:
-            attention_mask = build_causal_attention_mask(1, embeddings.shape[1])
+            attention_mask = build_causal_attention_mask(1, embeddings.shape[1], device=embeddings.device)
         for layer in self.layers:
             embeddings = layer(embeddings, attention_mask)

@@ -985,7 +985,7 @@ class TorchGptDecoder(nn.Module):
         self, token_ids: torch.Tensor, attention_mask: torch.Tensor = None
     ) -> dict[str, torch.Tensor]:
         if attention_mask is None:
-            attention_mask = build_causal_attention_mask(1, token_ids.shape[1])
+            attention_mask = build_causal_attention_mask(1, token_ids.shape[1], device=token_ids.device)

         tokens_embeddings = self.token_embed(token_ids)

@@ -1127,7 +1127,7 @@ def get_activation_fn(activation_name: str):  # type: ignore
     return activations.get(activation_name, nn.functional.relu)


-def build_causal_attention_mask(batch_size: int, seq_len: int) -> torch.Tensor:
+def build_causal_attention_mask(batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
     """
     Builds a batch of causal masks of shape (batch_size, 1, seq_len, seq_len) to feed
     to an attention layer.

@@ -1139,7 +1139,7 @@ def build_causal_attention_mask(batch_size: int, seq_len: int) -> torch.Tensor:
     Returns:
         Batch of causal masks.
     """
-    mask = torch.ones((batch_size, 1, seq_len, seq_len))
+    mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device)
     causal_mask = torch.tril(mask)
     return causal_mask
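
For reference, a minimal sketch (not part of the commit) of what the patched helper does: the new device argument makes the causal mask get allocated on the same device as the decoder inputs, rather than always on the CPU, which otherwise typically triggers a device-mismatch error once TorchGptDecoder runs on CUDA.

import torch

def build_causal_attention_mask(batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
    # Same logic as the patched helper: a lower-triangular (causal) mask of
    # shape (batch_size, 1, seq_len, seq_len), created directly on `device`.
    mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device)
    return torch.tril(mask)

# Usage mirroring the patched call sites in TorchGptDecoder
# (token_ids is a stand-in input for this sketch; swap in .to("cuda") for a GPU run):
token_ids = torch.randint(0, 100, (1, 8))
mask = build_causal_attention_mask(1, token_ids.shape[1], device=token_ids.device)
print(mask.shape, mask.device)  # torch.Size([1, 1, 8, 8]) cpu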