tianyuz commited on
Commit
1559edf
1 Parent(s): 22976fc

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +20 -12
README.md CHANGED
@@ -45,6 +45,10 @@ To predict a masked token, be sure to add a `[CLS]` token before the sentence fo
45
 
46
  A) Directly typing `[MASK]` in an input string and B) replacing a token with `[MASK]` after tokenization will yield different token sequences, and thus different prediction results. It is more appropriate to use `[MASK]` after tokenization (as it is consistent with how the model was pretrained). However, the Huggingface Inference API only supports typing `[MASK]` in the input string and produces less robust predictions.
47
 
 
 
 
 
48
  ## Example
49
 
50
  Here is an example by to illustrate how our model works as a masked language model. Notice the difference between running the following code example and running the Huggingface Inference API.
@@ -71,12 +75,16 @@ print(token_ids) # output: [4, 1602, 44, 24, 368, 6, 11, 21583, 8]
71
 
72
  # convert to tensor
73
  import torch
74
- token_tensor = torch.tensor([token_ids])
 
 
 
 
 
75
 
76
  # get the top 10 predictions of the masked token
77
- model = model.eval()
78
  with torch.no_grad():
79
- outputs = model(token_tensor)
80
  predictions = outputs[0][0, masked_idx].topk(10)
81
 
82
  for i, index_t in enumerate(predictions.indices):
@@ -85,16 +93,16 @@ for i, index_t in enumerate(predictions.indices):
85
  print(i, token)
86
 
87
  """
88
- 0 ワールドカップ
89
- 1 フェスティバル
90
- 2 オリンピック
91
- 3 サミット
92
- 4 東京オリンピック
93
- 5 総会
94
  6 全国大会
95
- 7 イベント
96
- 8 世界選手権
97
- 9 パーティー
98
  """
99
  ~~~~
100
 
 
45
 
46
  A) Directly typing `[MASK]` in an input string and B) replacing a token with `[MASK]` after tokenization will yield different token sequences, and thus different prediction results. It is more appropriate to use `[MASK]` after tokenization (as it is consistent with how the model was pretrained). However, the Huggingface Inference API only supports typing `[MASK]` in the input string and produces less robust predictions.
47
 
48
+ ## Note 3: Provide `position_ids` as an argument explicitly
49
+
50
+ When `position_ids` are not provided for a `Roberta*` model, Huggingface's `transformers` will automatically construct it but start from `padding_idx` instead of `0` (see [issue](https://github.com/rinnakk/japanese-pretrained-models/issues/3) and function `create_position_ids_from_input_ids()` in Huggingface's [implementation](https://github.com/huggingface/transformers/blob/master/src/transformers/models/roberta/modeling_roberta.py)), which unfortunately does not work as expected with `rinna/japanese-roberta-base` since the `padding_idx` of the corresponding tokenizer is not `0`. So please be sure to constrcut the `position_ids` by yourself and make it start from position id `0`.
51
+
52
  ## Example
53
 
54
  Here is an example by to illustrate how our model works as a masked language model. Notice the difference between running the following code example and running the Huggingface Inference API.
 
75
 
76
  # convert to tensor
77
  import torch
78
+ token_tensor = torch.LongTensor([token_ids])
79
+
80
+ # provide position ids explicitly
81
+ position_ids = list(range(0, token_tensor.size(1)))
82
+ print(position_ids) # output: [0, 1, 2, 3, 4, 5, 6, 7, 8]
83
+ position_id_tensor = torch.LongTensor([position_ids])
84
 
85
  # get the top 10 predictions of the masked token
 
86
  with torch.no_grad():
87
+ outputs = model(input_ids=token_tensor, position_ids=position_id_tensor)
88
  predictions = outputs[0][0, masked_idx].topk(10)
89
 
90
  for i, index_t in enumerate(predictions.indices):
 
93
  print(i, token)
94
 
95
  """
96
+ 0 総会
97
+ 1 サミット
98
+ 2 ワールドカップ
99
+ 3 フェスティバル
100
+ 4 大会
101
+ 5 オリンピック
102
  6 全国大会
103
+ 7 党大会
104
+ 8 イベント
105
+ 9 世界選手権
106
  """
107
  ~~~~
108