sidphbot commited on
Commit
2b3ea65
1 Parent(s): 5727aa4

conditional_gen adjustments

Browse files
Files changed (1) hide show
  1. src/Surveyor.py +16 -3
src/Surveyor.py CHANGED
@@ -146,6 +146,16 @@ class Surveyor:
146
  else:
147
  self.ledtokenizer = AutoTokenizer.from_pretrained(ledmodel_name)
148
  self.ledmodel = T5ForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
 
 
 
 
 
 
 
 
 
 
149
  self.ledmodel.eval()
150
  if not no_save_models:
151
  self.ledmodel.save_pretrained(models_dir + "/ledmodel")
@@ -170,12 +180,15 @@ class Surveyor:
170
  self.summ_model.eval()
171
  self.model = Summarizer(custom_model=self.summ_model, custom_tokenizer=self.summ_tokenizer)
172
 
173
- if 't5' not in ledmodel_name:
174
  self.ledtokenizer = LEDTokenizer.from_pretrained(ledmodel_name)
175
  self.ledmodel = LEDForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
176
- else:
177
- self.ledtokenizer = T5Tokenizer.from_pretrained(ledmodel_name)
178
  self.ledmodel = T5ForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
 
 
 
179
  self.ledmodel.eval()
180
 
181
  self.embedder = SentenceTransformer(models_dir + "/embedder")
 
146
  else:
147
  self.ledtokenizer = AutoTokenizer.from_pretrained(ledmodel_name)
148
  self.ledmodel = T5ForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
149
+
150
+ if 'led' in ledmodel_name:
151
+ self.ledtokenizer = LEDTokenizer.from_pretrained(ledmodel_name)
152
+ self.ledmodel = LEDForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
153
+ elif 't5' in ledmodel_name:
154
+ self.ledtokenizer = AutoTokenizer.from_pretrained(ledmodel_name)
155
+ self.ledmodel = T5ForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
156
+ elif 'bart' in ledmodel_name:
157
+ self.ledtokenizer = AutoTokenizer.from_pretrained(ledmodel_name)
158
+ self.ledmodel = BartForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
159
  self.ledmodel.eval()
160
  if not no_save_models:
161
  self.ledmodel.save_pretrained(models_dir + "/ledmodel")
 
180
  self.summ_model.eval()
181
  self.model = Summarizer(custom_model=self.summ_model, custom_tokenizer=self.summ_tokenizer)
182
 
183
+ if 'led' in ledmodel_name:
184
  self.ledtokenizer = LEDTokenizer.from_pretrained(ledmodel_name)
185
  self.ledmodel = LEDForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
186
+ elif 't5' in ledmodel_name:
187
+ self.ledtokenizer = AutoTokenizer.from_pretrained(ledmodel_name)
188
  self.ledmodel = T5ForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
189
+ elif 'bart' in ledmodel_name:
190
+ self.ledtokenizer = AutoTokenizer.from_pretrained(ledmodel_name)
191
+ self.ledmodel = BartForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
192
  self.ledmodel.eval()
193
 
194
  self.embedder = SentenceTransformer(models_dir + "/embedder")