Update README.md
README.md
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

tokenizer = AutoTokenizer.from_pretrained('allenai/longformer-scico')
model = AutoModelForSequenceClassification.from_pretrained('allenai/longformer-scico')

start_token = tokenizer.convert_tokens_to_ids("<m>")
end_token = tokenizer.convert_tokens_to_ids("</m>")

def get_global_attention(input_ids):
    global_attention_mask = torch.zeros(input_ids.shape)
    global_attention_mask[:, 0] = 1  # global attention to the CLS token
    start = torch.nonzero(input_ids == start_token)  # global attention to the <m> tokens
    end = torch.nonzero(input_ids == end_token)  # global attention to the </m> tokens
    globs = torch.cat((start, end))
    value = torch.ones(globs.shape[0])
    global_attention_mask.index_put_(tuple(globs.t()), value)
    return global_attention_mask
```
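A minimal usage sketch follows. It is an illustration, not part of this commit: the example sentences, the `</s></s>` separator between the two mention contexts, and the reading of the logits as per-class relation probabilities are all assumptions.

```python
# Hedged usage sketch -- the example sentences and the </s></s> separator
# between the two mention contexts are assumptions, not part of this commit.
sentence_1 = "Sunday was a great day for a <m> walk </m> in the park."
sentence_2 = "A <m> hike </m> in the mountains."
inputs = tokenizer(sentence_1 + " </s></s> " + sentence_2, return_tensors="pt")

# Give global attention to the CLS token and the mention markers.
global_attention_mask = get_global_attention(inputs["input_ids"])

with torch.no_grad():
    output = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        global_attention_mask=global_attention_mask,
    )

probs = torch.softmax(output.logits, dim=-1)  # one probability per relation class
```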