gchhablani commited on
Commit
875bee5
1 Parent(s): 63672a5
Files changed (1) hide show
  1. apps/mlm.py +2 -2
apps/mlm.py CHANGED
@@ -53,7 +53,7 @@ def app(state):
53
  mlm_state.unmasked_caption = caption
54
  ids = bert_tokenizer.encode(caption)
55
  mask_index = np.random.randint(1, len(ids) - 1)
56
- mlm_state.currently_masked_token = bert_tokenizer.convert_ids_to_tokens[[ids[mask_index]]][0]
57
  ids[mask_index] = bert_tokenizer.mask_token_id
58
  mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
59
  mlm_state.caption_lang_id = dummy_data.loc[first_index, "lang_id"]
@@ -83,7 +83,7 @@ def app(state):
83
  mlm_state.unmasked_caption = caption
84
  ids = bert_tokenizer.encode(caption)
85
  mask_index = np.random.randint(1, len(ids) - 1)
86
- mlm_state.currently_masked_token = bert_tokenizer.convert_ids_to_tokens[[ids[mask_index]]][0]
87
  ids[mask_index] = bert_tokenizer.mask_token_id
88
  mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
89
  mlm_state.caption_lang_id = sample.loc[0, "lang_id"]
 
53
  mlm_state.unmasked_caption = caption
54
  ids = bert_tokenizer.encode(caption)
55
  mask_index = np.random.randint(1, len(ids) - 1)
56
+ mlm_state.currently_masked_token = bert_tokenizer.convert_ids_to_tokens([ids[mask_index]])[0]
57
  ids[mask_index] = bert_tokenizer.mask_token_id
58
  mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
59
  mlm_state.caption_lang_id = dummy_data.loc[first_index, "lang_id"]
 
83
  mlm_state.unmasked_caption = caption
84
  ids = bert_tokenizer.encode(caption)
85
  mask_index = np.random.randint(1, len(ids) - 1)
86
+ mlm_state.currently_masked_token = bert_tokenizer.convert_ids_to_tokens([ids[mask_index]])[0]
87
  ids[mask_index] = bert_tokenizer.mask_token_id
88
  mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
89
  mlm_state.caption_lang_id = sample.loc[0, "lang_id"]