Blaise-g committed on
Commit
b59f31f
·
1 Parent(s): dbe8965

Update summarize.py

Browse files
Files changed (1) hide show
  1. summarize.py +15 -15
summarize.py CHANGED
@@ -47,24 +47,24 @@ def summarize(ids, mask, model, tokenizer, model_arch, **kwargs):
47
  attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
48
 
49
  if model_arch == 'LED':
50
- global_attention_mask = torch.zeros_like(attention_mask)
51
  # put global attention on <s> token
52
- global_attention_mask[:, 0] = 1
53
- summary_pred_ids = model.generate(
54
- input_ids,
55
- attention_mask=attention_mask,
56
- global_attention_mask=global_attention_mask,
57
- return_dict_in_generate=True,
58
- **kwargs,
59
- )
60
 
61
  else:
62
- summary_pred_ids = model.generate(
63
- input_ids,
64
- attention_mask=attention_mask,
65
- return_dict_in_generate=True,
66
- **kwargs,
67
- )
68
  summary = tokenizer.batch_decode(
69
  summary_pred_ids.sequences,
70
  skip_special_tokens=True,
 
47
  attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
48
 
49
  if model_arch == 'LED':
50
+ global_attention_mask = torch.zeros_like(attention_mask)
51
  # put global attention on <s> token
52
+ global_attention_mask[:, 0] = 1
53
+ summary_pred_ids = model.generate(
54
+ input_ids,
55
+ attention_mask=attention_mask,
56
+ global_attention_mask=global_attention_mask,
57
+ return_dict_in_generate=True,
58
+ **kwargs,
59
+ )
60
 
61
  else:
62
+ summary_pred_ids = model.generate(
63
+ input_ids,
64
+ attention_mask=attention_mask,
65
+ return_dict_in_generate=True,
66
+ **kwargs,
67
+ )
68
  summary = tokenizer.batch_decode(
69
  summary_pred_ids.sequences,
70
  skip_special_tokens=True,