---
language: zh
tags:
- roformer
- pytorch
- tf2.0
widget:
- text: "今天[MASK]很好,我想去公园玩!"
---
## Introduction
### TensorFlow version
https://github.com/ZhuiyiTechnology/roformer

### PyTorch and TF2.0 versions
https://github.com/JunnYu/RoFormer_pytorch

## PyTorch usage
```python
import torch
from transformers import RoFormerForMaskedLM, RoFormerTokenizer

text = "今天[MASK]很好,我想去公园玩!"
tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
pt_model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
pt_inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    # Logits for the single sentence in the batch: (seq_len, vocab_size)
    pt_outputs = pt_model(**pt_inputs).logits[0]
pt_outputs_sentence = "pytorch: "
for i, token_id in enumerate(pt_inputs["input_ids"][0].tolist()):
    if token_id == tokenizer.mask_token_id:
        # Show the 5 highest-scoring candidates for the masked position
        top_ids = pt_outputs[i].topk(k=5).indices.tolist()
        tokens = tokenizer.convert_ids_to_tokens(top_ids)
        pt_outputs_sentence += "[" + "||".join(tokens) + "]"
    else:
        # Copy ordinary tokens through; special tokens ([CLS]/[SEP]) are dropped
        pt_outputs_sentence += "".join(
            tokenizer.convert_ids_to_tokens([token_id], skip_special_tokens=True))
print(pt_outputs_sentence)
# pytorch: 今天[天气||天||阳光||太阳||空气]很好,我想去公园玩!
```
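The loop above mixes model inference with string assembly. The post-processing step on its own can be sketched in plain Python, independent of PyTorch or TensorFlow: for every position whose id equals the mask id, rank the vocabulary by score, keep the k best ids and join their tokens with `||`; other positions are copied through and special tokens are dropped. The function name `fill_masks`, the toy vocabulary and the scores below are hypothetical and only illustrate the logic; they are not part of the `transformers` API.

```python
from typing import Dict, List, Sequence, Set


def fill_masks(
    input_ids: Sequence[int],
    scores: Sequence[Sequence[float]],
    id_to_token: Dict[int, str],
    mask_id: int,
    special_ids: Set[int],
    k: int = 5,
) -> str:
    """Hypothetical helper mirroring the loop above: replace each mask
    position with its top-k candidate tokens joined by '||'."""
    pieces: List[str] = []
    for pos, token_id in enumerate(input_ids):
        if token_id == mask_id:
            row = scores[pos]
            # Rank vocabulary ids by their score at this position, highest first.
            ranked = sorted(range(len(row)), key=lambda v: row[v], reverse=True)
            pieces.append("[" + "||".join(id_to_token[v] for v in ranked[:k]) + "]")
        elif token_id not in special_ids:
            # Ordinary tokens are copied through; special tokens are skipped.
            pieces.append(id_to_token[token_id])
    return "".join(pieces)


# Toy example (hypothetical vocabulary and scores, not real model output).
vocab = {0: "[CLS]", 1: "[SEP]", 2: "[MASK]", 3: "今天", 4: "天气", 5: "很好", 6: "阳光"}
ids = [0, 3, 2, 5, 1]
zero = [0.0] * len(vocab)
# Only the score row at the mask position matters for the output.
mask_scores = [0.0, 0.0, 0.0, 0.1, 0.9, 0.0, 0.5]
scores = [zero, zero, mask_scores, zero, zero]
print(fill_masks(ids, scores, vocab, mask_id=2, special_ids={0, 1}, k=2))
```

With real models, `scores` would be the logits row per position (for example `pt_outputs`) and `id_to_token` would come from the tokenizer's vocabulary.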

## TensorFlow 2.0 usage
```python
import tensorflow as tf
from transformers import RoFormerTokenizer, TFRoFormerForMaskedLM

text = "今天[MASK]很好,我想去公园玩!"
tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
tf_model = TFRoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
tf_inputs = tokenizer(text, return_tensors="tf")
# Logits for the single sentence in the batch: (seq_len, vocab_size)
tf_outputs = tf_model(**tf_inputs, training=False).logits[0]
tf_outputs_sentence = "tf2.0: "
for i, token_id in enumerate(tf_inputs["input_ids"][0].numpy().tolist()):
    if token_id == tokenizer.mask_token_id:
        # Show the 5 highest-scoring candidates for the masked position
        top_ids = tf.math.top_k(tf_outputs[i], k=5).indices.numpy().tolist()
        tokens = tokenizer.convert_ids_to_tokens(top_ids)
        tf_outputs_sentence += "[" + "||".join(tokens) + "]"
    else:
        # Copy ordinary tokens through; special tokens ([CLS]/[SEP]) are dropped
        tf_outputs_sentence += "".join(
            tokenizer.convert_ids_to_tokens([token_id], skip_special_tokens=True))
print(tf_outputs_sentence)
# tf2.0: 今天[天气||天||阳光||太阳||空气]很好,我想去公园玩!
```
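Both loops above print only the candidate tokens. If you also want a rough sense of how confident the model is, one common approach (an addition here, not something this card does) is to apply a softmax over the vocabulary logits at the mask position and report each top-k candidate with its probability. The sketch below does this in plain Python on a hypothetical logit vector; `top_k_with_probs` and the toy vocabulary are illustrative names, and with the real models you would pass a row such as `pt_outputs[i]` or `tf_outputs[i]` converted to a list.

```python
import math
from typing import Dict, List, Sequence, Tuple


def top_k_with_probs(
    logits: Sequence[float], id_to_token: Dict[int, str], k: int = 5
) -> List[Tuple[str, float]]:
    """Return the k most likely tokens at one position with softmax probabilities."""
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank vocabulary ids by probability, highest first.
    ranked = sorted(range(len(probs)), key=lambda v: probs[v], reverse=True)
    return [(id_to_token[v], probs[v]) for v in ranked[:k]]


# Hypothetical logits over a three-token vocabulary.
vocab = {0: "天气", 1: "天", 2: "阳光"}
for token, p in top_k_with_probs([2.0, 1.0, 0.0], vocab, k=2):
    print(f"{token}: {p:.3f}")
```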

## Citation

BibTeX:

```tex
@misc{su2021roformer,
      title={RoFormer: Enhanced Transformer with Rotary Position Embedding}, 
      author={Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
      year={2021},
      eprint={2104.09864},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```