ccdv committed
Commit d6ca5e6
1 Parent(s): 8e0db59

block_stride + fixes + readme
Files changed (1):
  modeling_lsg_bart.py  +24 -229
modeling_lsg_bart.py CHANGED
@@ -54,15 +54,15 @@ class LSGBartConfig(BartConfig):
         self.sparsity_factor = sparsity_factor
         self.sparsity_type = sparsity_type
 
-        if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride"]:
+        if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
             logger.warning(
-                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride'], setting sparsity_type=None, computation will skip sparse attention")
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], setting sparsity_type=None, computation will skip sparse attention")
             self.sparsity_type = None
 
-        if self.sparsity_type == "stride":
+        if self.sparsity_type in ["stride", "block_stride"]:
             if self.sparsity_factor > self.encoder_attention_heads:
                 logger.warning(
-                    "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride sparsity"
+                    "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
                 )
 
         if self.num_global_tokens < 1:
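
Note (not part of the commit): a minimal usage sketch of the new sparsity type, assuming LSGBartConfig exposes `sparsity_type` and `sparsity_factor` as constructor keywords, as the assignments in the hunk above suggest.

    # Hedged sketch: LSGBartConfig is defined in modeling_lsg_bart.py; the keyword
    # names mirror the attributes assigned in the hunk above.
    from modeling_lsg_bart import LSGBartConfig

    config = LSGBartConfig(sparsity_type="block_stride", sparsity_factor=4)
    print(config.sparsity_type)   # "block_stride" is now accepted

    config = LSGBartConfig(sparsity_type="typo")
    print(config.sparsity_type)   # None: the warning above fires and sparse attention is skipped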
@@ -412,6 +412,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
             "pooling": self.get_sparse_tokens_with_pooling,
             "lsh": self.get_sparse_tokens_with_lsh,
             "stride": self.get_sparse_tokens_with_stride,
+            "block_stride": self.get_sparse_tokens_with_block_stride,
         }
 
         self.sparsity_type = config.sparsity_type
@@ -480,29 +481,32 @@ class LSGBartEncoderAttention(BaseSelfAttention):
         sparse_idx = sparse_idx.reshape(1, 1, -1, 1) + (torch.arange(h, device=keys.device) % self.sparsity_factor).reshape(1, h, 1, 1)
         sparse_idx = sparse_idx.expand(n, h, -1, 1)
 
-        """
-        t, b = self.block_size, t // self.block_size
-        sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
-        sparse_idx = sparse_idx.reshape(1, 1, 1, -1, 1) + (torch.arange(h, device=keys.device) % self.sparsity_factor).reshape(1, h, 1, 1, 1)
-        sparse_idx = sparse_idx + torch.arange(b, device=keys.device).reshape(1, 1, -1, 1, 1) * t
-        sparse_idx = sparse_idx.reshape(1, h, -1, 1).expand(n, h, -1, 1)
+        keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
+        values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
+        mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
+
+        return keys, values, mask
+
+    def get_sparse_tokens_with_block_stride(self, keys, values, mask):
+
+        if self.sparsity_factor == 1:
+            return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
+
+        n, h, t, d = keys.size()
 
-
         t, b = self.block_size, t // self.block_size
         sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device)
         sparse_idx = sparse_idx.reshape(1, 1, 1, -1, 1) + torch.arange(h, device=keys.device).reshape(1, h, 1, 1, 1) * (t // self.sparsity_factor)
         sparse_idx = (sparse_idx % t)
-        #sparse_idx[..., -t//2:, :] = (sparse_idx[..., -t//2:, :] + t//2) % t
         sparse_idx = sparse_idx + torch.arange(b, device=keys.device).reshape(1, 1, -1, 1, 1) * t
         sparse_idx = sparse_idx.reshape(1, h, -1, 1).expand(n, h, -1, 1)
-        """
 
         keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
         values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
         mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
 
         return keys, values, mask
-
+
     def get_sparse_tokens_with_lsh(self, keys, values, mask):
 
         if self.sparsity_factor == 1:
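
For reference (not part of the commit), a standalone sketch with toy sizes that reproduces the index arithmetic of the new get_sparse_tokens_with_block_stride: each head keeps a contiguous chunk of block_size // sparsity_factor positions per block, at a head-dependent offset, whereas plain "stride" keeps every sparsity_factor-th token per head.

    # Toy reproduction of the block_stride index pattern (sizes assumed for illustration).
    import torch

    block_size, sparsity_factor, h, b = 8, 4, 4, 2   # b = number of blocks
    chunk = block_size // sparsity_factor            # tokens kept per head and per block

    idx = torch.arange(chunk)
    # each head starts its chunk at a different offset inside the block
    idx = idx.reshape(1, 1, 1, -1, 1) + torch.arange(h).reshape(1, h, 1, 1, 1) * chunk
    idx = idx % block_size
    # shift indices by the position of each block
    idx = idx + torch.arange(b).reshape(1, 1, -1, 1, 1) * block_size
    idx = idx.reshape(1, h, -1, 1)

    print(idx[0, :, :, 0])
    # head 0 -> [0, 1, 8, 9], head 1 -> [2, 3, 10, 11], ...: contiguous sub-blocks
    # per head, unlike "stride" which keeps every sparsity_factor-th token.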
@@ -1163,25 +1167,13 @@ class LSGBartEncoder(LSGBartPretrainedModel):
         pad = t % self.block_size
 
         # Check if t is multiple of block_size and pad
-        if t > b and pad > 0:
+        if self.adaptive and t > b and pad > 0:
             pad_length = self.block_size - pad
             if input_ids is not None:
                 input_ids = torch.nn.functional.pad(input_ids, (0, pad_length), value=self.pad_idx)
             else:
                 inputs_embeds = torch.nn.functional.pad(inputs_embeds.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
             attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=0)
-
-        # else adaptive sequence length
-        elif self.adaptive:
-            # Get last non zero mask index
-            s = int(attention_mask.cumsum(dim=-1).argmax(dim=-1).max()) + 1
-            if s < t and self.block_size is not None:
-                s = max(2, s // self.block_size + 1) * self.block_size if s > b else s
-            if input_ids is not None:
-                input_ids = input_ids[:, :s]
-            else:
-                inputs_embeds = inputs_embeds[:, :s]
-            attention_mask = attention_mask[:, :s]
 
         n, t_ = attention_mask.size()
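
Worked example (toy numbers, not from the repo) of the padding arithmetic in the hunk above: when self.adaptive is set, the sequence is padded up to the next multiple of block_size.

    # Toy numbers illustrating the padded-length computation above.
    block_size, t = 128, 300
    pad = t % block_size                        # 44: t is not a multiple of block_size
    pad_length = block_size - pad               # 84 padding tokens appended
    assert (t + pad_length) % block_size == 0   # padded length 384 = 3 full blocks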
 
@@ -1207,9 +1199,7 @@ class LSGBartEncoder(LSGBartPretrainedModel):
             offset = 0
 
         # Adapt sequence to initial shape
-        if diff > 0:
-            context = torch.nn.functional.pad(context.transpose(-1, -2), pad=(0, diff), value=0).transpose(-1, -2)
-        elif diff < 0:
+        if diff < 0:
             context = context[:, :t + offset]
 
         if return_dict:
@@ -1321,7 +1311,7 @@ class LSGBartEncoder(LSGBartPretrainedModel):
         )
 
 
-class LSGBartDecoder(LSGBartPretrainedModel):
+class LSGBartDecoder(BartDecoder, LSGBartPretrainedModel):
     """
     Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`LSGBartDecoderLayer`
     Args:
@@ -1330,8 +1320,9 @@ class LSGBartDecoder(LSGBartPretrainedModel):
     """
 
     def __init__(self, config, embed_tokens=None):
-
-        super().__init__(config)
+
+        LSGBartPretrainedModel.__init__(self, config)
+
         self.dropout = config.dropout
         self.layerdrop = config.decoder_layerdrop
         self.padding_idx = config.pad_token_id
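
The next hunk removes the decoder methods that duplicated the Hugging Face implementation: with BartDecoder first in the base list, forward, get_input_embeddings, etc. are now inherited, while LSGBartPretrainedModel.__init__ is still called explicitly. A toy sketch of that pattern (class names below are placeholders, not from the repo):

    # Toy illustration of the multiple-inheritance pattern used above; Base and
    # RefDecoder stand in for LSGBartPretrainedModel and BartDecoder.
    class Base:
        def __init__(self, config):
            self.config = config

    class RefDecoder:
        def forward(self, x):
            return f"decoded({x})"

    class MyDecoder(RefDecoder, Base):
        def __init__(self, config):
            # mirrors LSGBartPretrainedModel.__init__(self, config) in the hunk above
            Base.__init__(self, config)

    decoder = MyDecoder(config={"decoder_layers": 2})
    print(decoder.forward("tokens"))   # forward is inherited from RefDecoder, no override needed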
@@ -1356,202 +1347,6 @@ class LSGBartDecoder(LSGBartPretrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_input_embeddings(self):
-        return self.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.embed_tokens = value
-
-    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
-        # create causal mask
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        combined_attention_mask = None
-        if input_shape[-1] > 1:
-            combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(self.device)
-
-        if attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
-            combined_attention_mask = (
-                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
-            )
-
-        return combined_attention_mask
-
-    def resize_inputs(self, inputs_embeds, attention_mask):
-        pad = 0
-
-        max_len = int(attention_mask.sum(dim=-1).max())
-        pad = attention_mask.size()[-1] - max_len
-        inputs_embeds = inputs_embeds[:, :max_len]
-        attention_mask = attention_mask[..., :max_len]
-        return pad, inputs_embeds, attention_mask
-
-    def forward(
-        self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
-
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        # retrieve input_ids and inputs_embeds
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
-        elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
-        else:
-            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
-        # past_key_values_length
-        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
-
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
-
-        # Resize to reduce computation
-        pad = 0
-        if self.adaptive:
-            if attention_mask is not None:
-                pad, inputs_embeds, attention_mask = self.resize_inputs(inputs_embeds, attention_mask)
-                input_shape = inputs_embeds.size()[:-1]
-            if encoder_attention_mask is not None:
-                _, encoder_hidden_states, encoder_attention_mask = self.resize_inputs(encoder_hidden_states, encoder_attention_mask)
-
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, input_shape, inputs_embeds, past_key_values_length
-        )
-
-        # expand encoder attention mask
-        if encoder_hidden_states is not None and encoder_attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
-
-        # embed positions
-        positions = self.embed_positions(input_shape, past_key_values_length)
-
-        hidden_states = inputs_embeds + positions
-        hidden_states = self.layernorm_embedding(hidden_states)
-
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
-        next_decoder_cache = () if use_cache else None
-
-        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
-        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
-            if attn_mask is not None:
-                if attn_mask.size()[0] != (len(self.layers)):
-                    raise ValueError(
-                        "The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
-                    )
-
-        for idx, decoder_layer in enumerate(self.layers):
-            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-            dropout_probability = random.uniform(0, 1)
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
-
-            past_key_value = past_key_values[idx] if past_key_values is not None else None
-
-            if self.gradient_checkpointing and self.training:
-
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, use_cache)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
-                    hidden_states,
-                    attention_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    head_mask[idx] if head_mask is not None else None,
-                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
-                    None,
-                )
-            else:
-
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=attention_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                    cross_attn_layer_head_mask=(
-                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
-                    ),
-                    past_key_value=past_key_value,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                )
-            hidden_states = layer_outputs[0]
-
-            if use_cache:
-                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-                if encoder_hidden_states is not None:
-                    all_cross_attentions += (layer_outputs[2],)
-
-        # Resize to original shape
-        hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), pad=(0, pad), value=0).transpose(-1, -2)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        next_cache = next_decoder_cache if use_cache else None
-        if not return_dict:
-            return tuple(
-                v
-                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
-                if v is not None
-            )
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-            cross_attentions=all_cross_attentions,
-        )
-
 
 class LSGBartModel(LSGBartPretrainedModel):
 