cspocketindia commited on
Commit
eb76698
1 Parent(s): 2e76149

Update tokenizer

Browse files
Files changed (1) hide show
  1. src/models/bert.py +3 -2
src/models/bert.py CHANGED
@@ -8,7 +8,7 @@ from torch import nn
8
  from transformers import RobertaTokenizer, RobertaModel, AdamW, RobertaConfig
9
  from sklearn.model_selection import train_test_split
10
 
11
- from transformers import BertTokenizer, BertForSequenceClassification, get_linear_schedule_with_warmup, AutoModel
12
  from torch.utils.data import TensorDataset, random_split, DataLoader, RandomSampler, SequentialSampler
13
 
14
 
@@ -20,7 +20,8 @@ class BERTClassifier():
20
 
21
  self.model_name = model_name
22
 
23
- self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name, do_lower_case=True)
 
24
 
25
  if model_name.startswith('jeevavijay10'):
26
  # self.model = torch.load(model_name)
 
8
  from transformers import RobertaTokenizer, RobertaModel, AdamW, RobertaConfig
9
  from sklearn.model_selection import train_test_split
10
 
11
+ from transformers import BertTokenizer, BertForSequenceClassification, get_linear_schedule_with_warmup, AutoModel, AutoTokenizer
12
  from torch.utils.data import TensorDataset, random_split, DataLoader, RandomSampler, SequentialSampler
13
 
14
 
 
20
 
21
  self.model_name = model_name
22
 
23
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
24
+ # self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name, do_lower_case=True)
25
 
26
  if model_name.startswith('jeevavijay10'):
27
  # self.model = torch.load(model_name)