Update summarize.py
summarize.py (+15 -15)
@@ -47,24 +47,24 @@ def summarize(ids, mask, model, tokenizer, model_arch, **kwargs):
     attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
 
     if model_arch == 'LED':
-
+        global_attention_mask = torch.zeros_like(attention_mask)
         # put global attention on <s> token
-
-
-
-
-
-
-
-
+        global_attention_mask[:, 0] = 1
+        summary_pred_ids = model.generate(
+            input_ids,
+            attention_mask=attention_mask,
+            global_attention_mask=global_attention_mask,
+            return_dict_in_generate=True,
+            **kwargs,
+        )
 
     else:
-
-
-
-
-
-
+        summary_pred_ids = model.generate(
+            input_ids,
+            attention_mask=attention_mask,
+            return_dict_in_generate=True,
+            **kwargs,
+        )
     summary = tokenizer.batch_decode(
         summary_pred_ids.sequences,
         skip_special_tokens=True,
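For reference: LED (Longformer Encoder-Decoder) models use windowed local attention over long inputs, and any position flagged with 1 in global_attention_mask attends to, and is attended by, every other token; the Hugging Face LED docs recommend putting global attention on the first <s> token for downstream tasks, which is what global_attention_mask[:, 0] = 1 does. Because generate is called with return_dict_in_generate=True, it returns an output object rather than a bare tensor, which is why the decode step reads summary_pred_ids.sequences. Below is a minimal calling sketch, assuming input_ids is prepared from ids just above the hunk, that the function returns the decoded summary list after the lines shown, and that the allenai/led-base-16384 checkpoint and long_article.txt path are illustrative stand-ins:

import torch
from transformers import AutoTokenizer, LEDForConditionalGeneration

# Illustrative checkpoint; any LED summarization checkpoint follows the same pattern.
checkpoint = "allenai/led-base-16384"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = LEDForConditionalGeneration.from_pretrained(checkpoint)
if torch.cuda.is_available():
    model = model.to("cuda")

# Tokenize a long document; this LED checkpoint accepts inputs up to 16384 tokens.
document = open("long_article.txt").read()  # placeholder input file
batch = tokenizer(document, return_tensors="pt", truncation=True, max_length=16384)

# model_arch='LED' selects the global-attention branch in the diff above.
summary = summarize(
    batch["input_ids"],
    batch["attention_mask"],
    model,
    tokenizer,
    model_arch="LED",
    max_length=256,
    num_beams=4,
)
print(summary[0])

Passing generation options such as num_beams and max_length through **kwargs keeps both branches identical apart from the global attention mask, so non-LED architectures go through the same call with one fewer argument.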