w11wo commited on
Commit
eb227e9
1 Parent(s): 1432ed9

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +5 -5
README.md CHANGED
@@ -126,12 +126,12 @@ This model was trained using the [Keras](https://keras.io/) framework. All train
126
  import keras
127
  import tensorflow as tf
128
  import numpy as np
 
129
 
130
- mlm_model = keras.models.load_model(
131
- "bert_mlm.h5", custom_objects={"MaskedLanguageModel": MaskedLanguageModel}
132
- )
133
 
134
  MAX_LEN = 32
 
135
 
136
  def inference(sequence):
137
  sequence = " ".join([c if c != "e" else "[mask]" for c in sequence])
@@ -140,10 +140,10 @@ def inference(sequence):
140
 
141
  tokens = tokens + pad
142
  input_ids = tf.convert_to_tensor(np.array([tokens]))
143
- prediction = mlm_model.predict(input_ids)
144
 
145
  # find masked idx token
146
- masked_index = np.where(input_ids == mask_token_id)
147
  masked_index = masked_index[1]
148
 
149
  # get prediction at those masked index only
126
  import keras
127
  import tensorflow as tf
128
  import numpy as np
129
+ from huggingface_hub import from_pretrained_keras
130
 
131
+ model = from_pretrained_keras("bookbot/id-g2p-bert")
 
 
132
 
133
  MAX_LEN = 32
134
+ MASK_TOKEN_ID = 30
135
 
136
  def inference(sequence):
137
  sequence = " ".join([c if c != "e" else "[mask]" for c in sequence])
140
 
141
  tokens = tokens + pad
142
  input_ids = tf.convert_to_tensor(np.array([tokens]))
143
+ prediction = model.predict(input_ids)
144
 
145
  # find masked idx token
146
+ masked_index = np.where(input_ids == MASK_TOKEN_ID)
147
  masked_index = masked_index[1]
148
 
149
  # get prediction at those masked index only