sidphbot committed
Commit d1763ec
1 Parent(s): 09d7bdd

t5 led exchange

Files changed (1):
  src/Surveyor.py  +13 -5
src/Surveyor.py CHANGED
@@ -131,8 +131,12 @@ class Surveyor:
             #self.summ_tokenizer.save_pretrained(models_dir + "/summ_tokenizer")
             self.model = Summarizer(custom_model=self.summ_model, custom_tokenizer=self.summ_tokenizer)
 
-            self.ledtokenizer = LEDTokenizer.from_pretrained(ledmodel_name)
-            self.ledmodel = LEDForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
+            if 't5' not in ledmodel_name:
+                self.ledtokenizer = LEDTokenizer.from_pretrained(ledmodel_name)
+                self.ledmodel = LEDForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
+            else:
+                self.ledtokenizer = T5Tokenizer.from_pretrained(ledmodel_name)
+                self.ledmodel = T5ForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
             self.ledmodel.eval()
             if not no_save_models:
                 self.ledmodel.save_pretrained(models_dir + "/ledmodel")
@@ -144,7 +148,7 @@ class Surveyor:
                 self.embedder.save(models_dir + "/embedder")
         else:
             print("\nInitializing from previously saved models at" + models_dir)
-            self.title_tokenizer = AutoTokenizer.from_pretrained(title_model_name).to(self.torch_device)
+            self.title_tokenizer = AutoTokenizer.from_pretrained(title_model_name)
             self.title_model = AutoModelForSeq2SeqLM.from_pretrained(models_dir + "/title_model").to(self.torch_device)
             self.title_model.eval()
 
@@ -157,8 +161,12 @@ class Surveyor:
             self.summ_model.eval()
             self.model = Summarizer(custom_model=self.summ_model, custom_tokenizer=self.summ_tokenizer)
 
-            self.ledtokenizer = LEDTokenizer.from_pretrained(ledmodel_name)
-            self.ledmodel = LEDForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
+            if 't5' not in ledmodel_name:
+                self.ledtokenizer = LEDTokenizer.from_pretrained(ledmodel_name)
+                self.ledmodel = LEDForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
+            else:
+                self.ledtokenizer = T5Tokenizer.from_pretrained(ledmodel_name)
+                self.ledmodel = T5ForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
             self.ledmodel.eval()
 
             self.embedder = SentenceTransformer(models_dir + "/embedder")
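
For reference, the branches added here presume that T5Tokenizer and T5ForConditionalGeneration are imported alongside the existing LED classes in src/Surveyor.py. Below is a minimal, self-contained sketch of the same selection logic; only the 't5' substring check and the transformers class names come from the diff, while the helper name, model names, and device argument are illustrative.

from transformers import (
    LEDTokenizer, LEDForConditionalGeneration,
    T5Tokenizer, T5ForConditionalGeneration,
)

def load_abstractive_model(ledmodel_name, torch_device="cpu"):
    # Mirror the branching in this commit: use the T5 classes whenever
    # the configured model name contains 't5', otherwise use LED.
    if 't5' not in ledmodel_name:
        tokenizer = LEDTokenizer.from_pretrained(ledmodel_name)
        model = LEDForConditionalGeneration.from_pretrained(ledmodel_name).to(torch_device)
    else:
        tokenizer = T5Tokenizer.from_pretrained(ledmodel_name)
        model = T5ForConditionalGeneration.from_pretrained(ledmodel_name).to(torch_device)
    model.eval()  # inference only, matching the original code
    return tokenizer, model

# Example usage (model names are illustrative, not taken from the commit):
# tokenizer, model = load_abstractive_model("allenai/led-large-16384-arxiv")
# tokenizer, model = load_abstractive_model("google/flan-t5-large", torch_device="cuda")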