sonsus committed on
Commit
399209a
1 Parent(s): 8e6a988

BatchEncoding wrapper for custom tokenizer output

Browse files
Files changed (1) hide show
  1. harim_plus.py +14 -4
harim_plus.py CHANGED
@@ -7,9 +7,12 @@ import torch.nn.functional as F
7
  from transformers import (AutoModelForSeq2SeqLM,
8
  AutoTokenizer,
9
  PreTrainedTokenizer,
10
- PreTrainedTokenizerFast)
 
 
 
11
  import pandas as pd
12
- from tqdm import tqdm
13
 
14
  from typing import List, Dict, Union
15
  from collections import defaultdict
@@ -201,8 +204,15 @@ class Harimplus_Scorer:
201
  emp_in = self._prep_input( mini_e_, src_or_tgt='src' )
202
 
203
 
204
- tgt_mask = tgt_in.attention_mask
205
-
 
 
 
 
 
 
 
206
  src_in = src_in.to(self._device)
207
  emp_in = emp_in.to(self._device)
208
  tgt_in = tgt_in.to(self._device)
 
7
  from transformers import (AutoModelForSeq2SeqLM,
8
  AutoTokenizer,
9
  PreTrainedTokenizer,
10
+ PreTrainedTokenizerFast,
11
+ )
12
+ from transformers.tokenization_utils_base import BatchEncoding # for custom tokenizer other than huggingface
13
+
14
  import pandas as pd
15
+ from tqdm import tqdm
16
 
17
  from typing import List, Dict, Union
18
  from collections import defaultdict
 
204
  emp_in = self._prep_input( mini_e_, src_or_tgt='src' )
205
 
206
 
207
+ tgt_mask = tgt_in.attention_mask # torch.Tensor
208
+ # if not tokenizer loaded from huggingface, this might cause some problem (.to(device))
209
+ if not isinstance(src_in, BatchEncoding):
210
+ src_in = BatchEncoding(src_in)
211
+ if not isinstance(emp_in, BatchEncoding):
212
+ emp_in = BatchEncoding(emp_in)
213
+ if not isinstance(tgt_in, BatchEncoding):
214
+ tgt_in = BatchEncoding(tgt_in)
215
+
216
  src_in = src_in.to(self._device)
217
  emp_in = emp_in.to(self._device)
218
  tgt_in = tgt_in.to(self._device)