---
language: zh
tags:
- roformer-v2
- pytorch
- tf2.0
inference: False
---
## Introduction

### TensorFlow version

https://github.com/ZhuiyiTechnology/roformer-v2

### PyTorch and TF 2.0 versions

https://github.com/JunnYu/RoFormer_pytorch

### Installation

- `pip install roformer==0.4.3`
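
Because the install command pins an exact release, it can help to confirm which version is actually present in the environment before running the examples. The following is a minimal sketch using only the standard library; the helper name `matches_pin` is hypothetical, and it assumes the installed distribution is named `roformer`, as in the pip command above.

```python
from importlib.metadata import PackageNotFoundError, version


def matches_pin(dist_name: str, pinned: str) -> bool:
    """Return True only if `dist_name` is installed at exactly `pinned`."""
    try:
        return version(dist_name) == pinned
    except PackageNotFoundError:
        # The distribution is not installed in this environment at all.
        return False


# Assumption: the distribution name matches the pip command above.
print(matches_pin("roformer", "0.4.3"))
```

If this prints `False`, reinstalling with the pinned command above is the simplest fix.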

## Evaluation comparison

### CLUE dev-set classification results (base and large models)

| Model | iflytek | tnews | afqmc | cmnli | ocnli | wsc | csl |
| :-----: | :-----: | :---: | :---: | :---: | :---: | :---: | :---: |
| BERT | 60.06 | 56.80 | 72.41 | 79.56 | 73.93 | 78.62 | 83.93 |
| RoBERTa | 60.64 | 58.06 | 74.05 | 81.24 | 76.00 | 87.50 | 84.50 |
| RoFormer | 60.91 | 57.54 | 73.52 | 80.92 | 76.07 | 86.84 | 84.63 |
| RoFormerV2<sup>*</sup> | 60.87 | 56.54 | 72.75 | 80.34 | 75.36 | 80.92 | 84.67 |
| GAU-α | 61.41 | 57.76 | 74.17 | 81.82 | 75.86 | 79.93 | 85.67 |
| RoFormer-pytorch (this repo) | 60.60 | 57.51 | 74.44 | 80.79 | 75.67 | 86.84 | 84.77 |
| RoFormerV2-pytorch (this repo) | **62.87** | 59.03 | **76.20** | 80.85 | 79.73 | 87.82 | **91.87** |
| GAU-α-pytorch (Adafactor) | 61.18 | 57.52 | 73.42 | 80.91 | 75.69 | 80.59 | 85.5 |
| GAU-α-pytorch (AdamW wd0.01 warmup0.1) | 60.68 | 57.95 | 73.08 | 81.02 | 75.36 | 81.25 | 83.93 |
| RoFormerV2-large-pytorch (this repo) | 61.75 | **59.21** | 76.14 | 82.35 | **81.73** | **91.45** | 91.5 |
| Chinesebert-large-pytorch | 61.25 | 58.67 | 74.70 | **82.65** | 79.63 | 87.83 | 84.97 |

### CLUE 1.0 test-set classification results (base and large models)

| Model | iflytek | tnews | afqmc | cmnli | ocnli | wsc | csl |
| :-----: | :-----: | :---: | :---: | :---: | :---: | :---: | :---: |
| RoFormer-pytorch (this repo) | 59.54 | 57.34 | 74.46 | 80.23 | 73.67 | 80.69 | 84.57 |
| RoFormerV2-pytorch (this repo) | **63.15** | 58.24 | 75.42 | 80.59 | 74.17 | 83.79 | 83.73 |
| GAU-α-pytorch (Adafactor) | 61.38 | 57.08 | 74.05 | 80.37 | 73.53 | 74.83 | **85.6** |
| GAU-α-pytorch (AdamW wd0.01 warmup0.1) | 60.54 | 57.67 | 72.44 | 80.32 | 72.97 | 76.55 | 84.13 |
| RoFormerV2-large-pytorch (this repo) | 61.85 | **59.13** | **76.38** | 80.97 | 76.23 | **85.86** | 84.33 |
| Chinesebert-large-pytorch | 61.54 | 58.57 | 74.8 | **81.94** | **76.93** | 79.66 | 85.1 |
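
### Notes:

- RoFormerV2<sup>*</sup> denotes the RoFormerV2 model trained without multi-task learning. Su Jianlin has not open-sourced that model; thanks to him for pointing this out.
- Results without the `pytorch` suffix are copied from the [GAU-alpha](https://github.com/ZhuiyiTechnology/GAU-alpha) repository.
- Results with the `pytorch` suffix come from models I trained myself.
- Su's code classifies directly from the [CLS] token representation, whereas this repo uses the classification head below, which adds two dropout layers, one dense layer, and one ReLU activation.
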
```python
import torch.nn as nn
from transformers.activations import ACT2FN


class RoFormerClassificationHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

        self.config = config

    def forward(self, features, **kwargs):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = ACT2FN[self.config.hidden_act](x)  # ReLU in this configuration
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
```
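To make the difference between the two heads concrete, the sketch below puts them side by side and compares output shapes and parameter counts. It is illustrative only: `DirectClsHead` is a hypothetical stand-in for "classify straight from the [CLS] vector" and is not taken from Su's code, `MlpClsHead` mirrors the layer order of the head above with the activation hard-coded to ReLU, and the tiny config values are made up for the demonstration.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn


class DirectClsHead(nn.Module):
    """Hypothetical stand-in: project the [CLS] vector straight to the labels."""

    def __init__(self, config):
        super().__init__()
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features):
        return self.out_proj(features[:, 0, :])


class MlpClsHead(nn.Module):
    """Same layer order as the head above: dropout, dense, ReLU, dropout, projection."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features):
        x = self.dropout(features[:, 0, :])
        x = torch.relu(self.dense(x))
        return self.out_proj(self.dropout(x))


def count_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


# Made-up toy sizes, chosen only to keep the example small.
config = SimpleNamespace(hidden_size=8, num_labels=3, hidden_dropout_prob=0.1)
features = torch.randn(2, 5, config.hidden_size)  # (batch, seq_len, hidden)

direct, mlp = DirectClsHead(config), MlpClsHead(config)
print(direct(features).shape, mlp(features).shape)  # both (batch, num_labels)
# The extra dense layer adds hidden_size * hidden_size weights plus hidden_size biases.
print(count_params(mlp) - count_params(direct))
```

Calling `mlp.eval()` disables both dropout layers, so at inference time the only extra cost of the larger head is the dense layer and its activation.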

## Usage with PyTorch and TF 2.0

```python
import torch
import tensorflow as tf
from transformers import BertTokenizer
from roformer import RoFormerForMaskedLM, TFRoFormerForMaskedLM

text = "今天[MASK]很好,我[MASK]去公园玩。"
tokenizer = BertTokenizer.from_pretrained("junnyu/roformer_v2_chinese_char_base")
pt_model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_v2_chinese_char_base")
# Load the TF model from the PyTorch weights.
tf_model = TFRoFormerForMaskedLM.from_pretrained(
    "junnyu/roformer_v2_chinese_char_base", from_pt=True
)
pt_inputs = tokenizer(text, return_tensors="pt")
tf_inputs = tokenizer(text, return_tensors="tf")

# pytorch
with torch.no_grad():
    pt_outputs = pt_model(**pt_inputs).logits[0]
pt_outputs_sentence = "pytorch: "
for i, token_id in enumerate(tokenizer.encode(text)):
    if token_id == tokenizer.mask_token_id:
        # Replace each [MASK] with its top-5 predicted tokens.
        tokens = tokenizer.convert_ids_to_tokens(pt_outputs[i].topk(k=5)[1])
        pt_outputs_sentence += "[" + "||".join(tokens) + "]"
    else:
        pt_outputs_sentence += "".join(
            tokenizer.convert_ids_to_tokens([token_id], skip_special_tokens=True)
        )
print(pt_outputs_sentence)

# tf
tf_outputs = tf_model(**tf_inputs, training=False).logits[0]
tf_outputs_sentence = "tf: "
for i, token_id in enumerate(tokenizer.encode(text)):
    if token_id == tokenizer.mask_token_id:
        tokens = tokenizer.convert_ids_to_tokens(tf.math.top_k(tf_outputs[i], k=5)[1])
        tf_outputs_sentence += "[" + "||".join(tokens) + "]"
    else:
        tf_outputs_sentence += "".join(
            tokenizer.convert_ids_to_tokens([token_id], skip_special_tokens=True)
        )
print(tf_outputs_sentence)

# Outputs for the small, base, and large checkpoints:
# small
# pytorch: 今天[的||,||是||很||也]很好,我[要||会||是||想||在]去公园玩。
# tf: 今天[的||,||是||很||也]很好,我[要||会||是||想||在]去公园玩。
# base
# pytorch: 今天[我||天||晴||园||玩]很好,我[想||要||会||就||带]去公园玩。
# tf: 今天[我||天||晴||园||玩]很好,我[想||要||会||就||带]去公园玩。
# large
# pytorch: 今天[天||气||我||空||阳]很好,我[又||想||会||就||爱]去公园玩。
# tf: 今天[天||气||我||空||阳]很好,我[又||想||会||just||爱]去公园玩。
```
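The PyTorch and TF loops above differ only in how the top-k ids are obtained; the string assembly is the same in both. As a rough sketch, that shared step could be factored into a framework-independent helper. The name `render_predictions`, the hard-coded special-token set, and the candidate tokens in the demo are all illustrative assumptions, not part of the `roformer` package and not real model output.

```python
from typing import Dict, List, Sequence

MASK = "[MASK]"
# Assumption: the BERT-style special tokens that should not appear in the output.
SPECIAL = {"[CLS]", "[SEP]", "[PAD]"}


def render_predictions(tokens: Sequence[str], candidates: Dict[int, List[str]]) -> str:
    """Join tokens into a sentence, replacing each [MASK] with its candidate list.

    `candidates` maps the position of a [MASK] token to its top-k predicted tokens.
    """
    pieces = []
    for i, tok in enumerate(tokens):
        if tok == MASK:
            pieces.append("[" + "||".join(candidates[i]) + "]")
        elif tok not in SPECIAL:
            pieces.append(tok)
    return "".join(pieces)


# Hand-written candidates, for illustration only.
tokens = ["[CLS]", "今", "天", "[MASK]", "很", "好", "[SEP]"]
print(render_predictions(tokens, {3: ["天", "气"]}))  # 今天[天||气]很好
```

With a helper like this, each framework only needs to produce the `candidates` mapping, for example from `topk` in PyTorch or `tf.math.top_k` in TF.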

## Citation

BibTeX:

```tex
@misc{su2021roformer,
      title={RoFormer: Enhanced Transformer with Rotary Position Embedding},
      author={Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
      year={2021},
      eprint={2104.09864},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```

```tex
@techreport{roformerv2,
  title={RoFormerV2: A Faster and Better RoFormer - ZhuiyiAI},
  author={Jianlin Su and Shengfeng Pan and Bo Wen and Yunfeng Liu},
  year={2022},
  url="https://github.com/ZhuiyiTechnology/roformer-v2",
}
```