ccdv committed on
Commit
00743f3
1 Parent(s): 1d0db05

block_stride + fixes + readme

Browse files
Files changed (2) hide show
  1. README.md +1 -0
  2. modeling_lsg_bart.py +6 -215
README.md CHANGED
@@ -45,6 +45,7 @@ It achieves the following results on the test set:
45
  | Length | Global tokens | Fine-tuning | Block Size | Sparsity | Connexions | R1 | R2 | RL | RLsum |
46
  |:------ |:------------- |:----------- |:---------- |:-------- | :--------- |:----- |:----- |:----- |:----- |
47
  | 16384 | 64 | Full | 256 | 0 | 768 | 48.74 | 20.88 | 28.50 | 44.23 |
 
48
  | 16384 | 64 | Global only | 256 | 0 | 768 | 48.08 | 20.42 | 28.00 | 43.65 |
49
  | 16384 | 1 | None | 256 | 0 | 768 | 47.03 | 20.19 | 28.26 | 42.69 |
50
 
 
45
  | Length | Global tokens | Fine-tuning | Block Size | Sparsity | Connexions | R1 | R2 | RL | RLsum |
46
  |:------ |:------------- |:----------- |:---------- |:-------- | :--------- |:----- |:----- |:----- |:----- |
47
  | 16384 | 64 | Full | 256 | 0 | 768 | 48.74 | 20.88 | 28.50 | 44.23 |
48
+ | 16384 | 1 | Full | 256 | 0 | 768 | 48.66 | 20.92 | 28.50 | 44.18 |
49
  | 16384 | 64 | Global only | 256 | 0 | 768 | 48.08 | 20.42 | 28.00 | 43.65 |
50
  | 16384 | 1 | None | 256 | 0 | 768 | 47.03 | 20.19 | 28.26 | 42.69 |
51
 
modeling_lsg_bart.py CHANGED
@@ -1167,25 +1167,13 @@ class LSGBartEncoder(LSGBartPretrainedModel):
1167
  pad = t % self.block_size
1168
 
1169
  # Check if t is multiple of block_size and pad
1170
- if t > b and pad > 0:
1171
  pad_length = self.block_size - pad
1172
  if input_ids is not None:
1173
  input_ids = torch.nn.functional.pad(input_ids, (0, pad_length), value=self.pad_idx)
1174
  else:
1175
  inputs_embeds = torch.nn.functional.pad(inputs_embeds.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
1176
  attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=0)
1177
-
1178
- # else adaptive sequence length
1179
- elif self.adaptive:
1180
- # Get last non zero mask index
1181
- s = int(attention_mask.cumsum(dim=-1).argmax(dim=-1).max()) + 1
1182
- if s < t and self.block_size is not None:
1183
- s = max(2, s // self.block_size + 1) * self.block_size if s > b else s
1184
- if input_ids is not None:
1185
- input_ids = input_ids[:, :s]
1186
- else:
1187
- inputs_embeds = inputs_embeds[:, :s]
1188
- attention_mask = attention_mask[:, :s]
1189
 
1190
  n, t_ = attention_mask.size()
1191
 
@@ -1211,9 +1199,7 @@ class LSGBartEncoder(LSGBartPretrainedModel):
1211
  offset = 0
1212
 
1213
  # Adapt sequence to initial shape
1214
- if diff > 0:
1215
- context = torch.nn.functional.pad(context.transpose(-1, -2), pad=(0, diff), value=0).transpose(-1, -2)
1216
- elif diff < 0:
1217
  context = context[:, :t + offset]
1218
 
1219
  if return_dict:
@@ -1325,7 +1311,7 @@ class LSGBartEncoder(LSGBartPretrainedModel):
1325
  )
1326
 
1327
 
1328
- class LSGBartDecoder(LSGBartPretrainedModel):
1329
  """
1330
  Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`LSGBartDecoderLayer`
1331
  Args:
@@ -1334,8 +1320,9 @@ class LSGBartDecoder(LSGBartPretrainedModel):
1334
  """
1335
 
1336
  def __init__(self, config, embed_tokens=None):
1337
-
1338
- super().__init__(config)
 
1339
  self.dropout = config.dropout
1340
  self.layerdrop = config.decoder_layerdrop
1341
  self.padding_idx = config.pad_token_id
@@ -1360,202 +1347,6 @@ class LSGBartDecoder(LSGBartPretrainedModel):
1360
  # Initialize weights and apply final processing
1361
  self.post_init()
1362
 
1363
- def get_input_embeddings(self):
1364
- return self.embed_tokens
1365
-
1366
- def set_input_embeddings(self, value):
1367
- self.embed_tokens = value
1368
-
1369
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
1370
- # create causal mask
1371
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1372
- combined_attention_mask = None
1373
- if input_shape[-1] > 1:
1374
- combined_attention_mask = _make_causal_mask(
1375
- input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
1376
- ).to(self.device)
1377
-
1378
- if attention_mask is not None:
1379
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1380
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
1381
- combined_attention_mask = (
1382
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
1383
- )
1384
-
1385
- return combined_attention_mask
1386
-
1387
- def resize_inputs(self, inputs_embeds, attention_mask):
1388
- pad = 0
1389
-
1390
- max_len = int(attention_mask.sum(dim=-1).max())
1391
- pad = attention_mask.size()[-1] - max_len
1392
- inputs_embeds = inputs_embeds[:, :max_len]
1393
- attention_mask = attention_mask[..., :max_len]
1394
- return pad, inputs_embeds, attention_mask
1395
-
1396
- def forward(
1397
- self,
1398
- input_ids=None,
1399
- attention_mask=None,
1400
- encoder_hidden_states=None,
1401
- encoder_attention_mask=None,
1402
- head_mask=None,
1403
- cross_attn_head_mask=None,
1404
- past_key_values=None,
1405
- inputs_embeds=None,
1406
- use_cache=None,
1407
- output_attentions=None,
1408
- output_hidden_states=None,
1409
- return_dict=None,
1410
- ):
1411
-
1412
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1413
- output_hidden_states = (
1414
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1415
- )
1416
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1417
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1418
-
1419
- # retrieve input_ids and inputs_embeds
1420
- if input_ids is not None and inputs_embeds is not None:
1421
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1422
- elif input_ids is not None:
1423
- input_shape = input_ids.size()
1424
- input_ids = input_ids.view(-1, input_shape[-1])
1425
- elif inputs_embeds is not None:
1426
- input_shape = inputs_embeds.size()[:-1]
1427
- else:
1428
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1429
-
1430
- # past_key_values_length
1431
- past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1432
-
1433
- if inputs_embeds is None:
1434
- inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1435
-
1436
- # Resize to reduce computation
1437
- pad = 0
1438
- if self.adaptive:
1439
- if attention_mask is not None:
1440
- pad, inputs_embeds, attention_mask = self.resize_inputs(inputs_embeds, attention_mask)
1441
- input_shape = inputs_embeds.size()[:-1]
1442
- if encoder_attention_mask is not None:
1443
- _, encoder_hidden_states, encoder_attention_mask = self.resize_inputs(encoder_hidden_states, encoder_attention_mask)
1444
-
1445
- attention_mask = self._prepare_decoder_attention_mask(
1446
- attention_mask, input_shape, inputs_embeds, past_key_values_length
1447
- )
1448
-
1449
- # expand encoder attention mask
1450
- if encoder_hidden_states is not None and encoder_attention_mask is not None:
1451
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1452
- encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
1453
-
1454
- # embed positions
1455
- positions = self.embed_positions(input_shape, past_key_values_length)
1456
-
1457
- hidden_states = inputs_embeds + positions
1458
- hidden_states = self.layernorm_embedding(hidden_states)
1459
-
1460
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1461
-
1462
- # decoder layers
1463
- all_hidden_states = () if output_hidden_states else None
1464
- all_self_attns = () if output_attentions else None
1465
- all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
1466
- next_decoder_cache = () if use_cache else None
1467
-
1468
- # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
1469
- for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
1470
- if attn_mask is not None:
1471
- if attn_mask.size()[0] != (len(self.layers)):
1472
- raise ValueError(
1473
- "The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
1474
- )
1475
-
1476
- for idx, decoder_layer in enumerate(self.layers):
1477
- # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1478
- if output_hidden_states:
1479
- all_hidden_states += (hidden_states,)
1480
- dropout_probability = random.uniform(0, 1)
1481
- if self.training and (dropout_probability < self.layerdrop):
1482
- continue
1483
-
1484
- past_key_value = past_key_values[idx] if past_key_values is not None else None
1485
-
1486
- if self.gradient_checkpointing and self.training:
1487
-
1488
- if use_cache:
1489
- logger.warning(
1490
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1491
- )
1492
- use_cache = False
1493
-
1494
- def create_custom_forward(module):
1495
- def custom_forward(*inputs):
1496
- # None for past_key_value
1497
- return module(*inputs, output_attentions, use_cache)
1498
-
1499
- return custom_forward
1500
-
1501
- layer_outputs = torch.utils.checkpoint.checkpoint(
1502
- create_custom_forward(decoder_layer),
1503
- hidden_states,
1504
- attention_mask,
1505
- encoder_hidden_states,
1506
- encoder_attention_mask,
1507
- head_mask[idx] if head_mask is not None else None,
1508
- cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
1509
- None,
1510
- )
1511
- else:
1512
-
1513
- layer_outputs = decoder_layer(
1514
- hidden_states,
1515
- attention_mask=attention_mask,
1516
- encoder_hidden_states=encoder_hidden_states,
1517
- encoder_attention_mask=encoder_attention_mask,
1518
- layer_head_mask=(head_mask[idx] if head_mask is not None else None),
1519
- cross_attn_layer_head_mask=(
1520
- cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
1521
- ),
1522
- past_key_value=past_key_value,
1523
- output_attentions=output_attentions,
1524
- use_cache=use_cache,
1525
- )
1526
- hidden_states = layer_outputs[0]
1527
-
1528
- if use_cache:
1529
- next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
1530
-
1531
- if output_attentions:
1532
- all_self_attns += (layer_outputs[1],)
1533
-
1534
- if encoder_hidden_states is not None:
1535
- all_cross_attentions += (layer_outputs[2],)
1536
-
1537
- # Resize to original shape
1538
- hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), pad=(0, pad), value=0).transpose(-1, -2)
1539
-
1540
- # add hidden states from the last decoder layer
1541
- if output_hidden_states:
1542
- all_hidden_states += (hidden_states,)
1543
-
1544
- next_cache = next_decoder_cache if use_cache else None
1545
- if not return_dict:
1546
- return tuple(
1547
- v
1548
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
1549
- if v is not None
1550
- )
1551
- return BaseModelOutputWithPastAndCrossAttentions(
1552
- last_hidden_state=hidden_states,
1553
- past_key_values=next_cache,
1554
- hidden_states=all_hidden_states,
1555
- attentions=all_self_attns,
1556
- cross_attentions=all_cross_attentions,
1557
- )
1558
-
1559
 
1560
  class LSGBartModel(LSGBartPretrainedModel):
1561
 
 
1167
  pad = t % self.block_size
1168
 
1169
  # Check if t is multiple of block_size and pad
1170
+ if self.adaptive and t > b and pad > 0:
1171
  pad_length = self.block_size - pad
1172
  if input_ids is not None:
1173
  input_ids = torch.nn.functional.pad(input_ids, (0, pad_length), value=self.pad_idx)
1174
  else:
1175
  inputs_embeds = torch.nn.functional.pad(inputs_embeds.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
1176
  attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=0)
 
 
 
 
 
 
 
 
 
 
 
 
1177
 
1178
  n, t_ = attention_mask.size()
1179
 
 
1199
  offset = 0
1200
 
1201
  # Adapt sequence to initial shape
1202
+ if diff < 0:
 
 
1203
  context = context[:, :t + offset]
1204
 
1205
  if return_dict:
 
1311
  )
1312
 
1313
 
1314
+ class LSGBartDecoder(BartDecoder, LSGBartPretrainedModel):
1315
  """
1316
  Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`LSGBartDecoderLayer`
1317
  Args:
 
1320
  """
1321
 
1322
  def __init__(self, config, embed_tokens=None):
1323
+
1324
+ LSGBartPretrainedModel.__init__(self, config)
1325
+
1326
  self.dropout = config.dropout
1327
  self.layerdrop = config.decoder_layerdrop
1328
  self.padding_idx = config.pad_token_id
 
1347
  # Initialize weights and apply final processing
1348
  self.post_init()
1349
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1350
 
1351
  class LSGBartModel(LSGBartPretrainedModel):
1352