Spaces:
Build error
Build error
conditional_gen adjustments
Browse files- src/Surveyor.py +16 -3
src/Surveyor.py
CHANGED
@@ -146,6 +146,16 @@ class Surveyor:
|
|
146 |
else:
|
147 |
self.ledtokenizer = AutoTokenizer.from_pretrained(ledmodel_name)
|
148 |
self.ledmodel = T5ForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
self.ledmodel.eval()
|
150 |
if not no_save_models:
|
151 |
self.ledmodel.save_pretrained(models_dir + "/ledmodel")
|
@@ -170,12 +180,15 @@ class Surveyor:
|
|
170 |
self.summ_model.eval()
|
171 |
self.model = Summarizer(custom_model=self.summ_model, custom_tokenizer=self.summ_tokenizer)
|
172 |
|
173 |
-
if '
|
174 |
self.ledtokenizer = LEDTokenizer.from_pretrained(ledmodel_name)
|
175 |
self.ledmodel = LEDForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
|
176 |
-
|
177 |
-
self.ledtokenizer =
|
178 |
self.ledmodel = T5ForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
|
|
|
|
|
|
|
179 |
self.ledmodel.eval()
|
180 |
|
181 |
self.embedder = SentenceTransformer(models_dir + "/embedder")
|
|
|
146 |
else:
|
147 |
self.ledtokenizer = AutoTokenizer.from_pretrained(ledmodel_name)
|
148 |
self.ledmodel = T5ForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
|
149 |
+
|
150 |
+
if 'led' in ledmodel_name:
|
151 |
+
self.ledtokenizer = LEDTokenizer.from_pretrained(ledmodel_name)
|
152 |
+
self.ledmodel = LEDForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
|
153 |
+
elif 't5' in ledmodel_name:
|
154 |
+
self.ledtokenizer = AutoTokenizer.from_pretrained(ledmodel_name)
|
155 |
+
self.ledmodel = T5ForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
|
156 |
+
elif 'bart' in ledmodel_name:
|
157 |
+
self.ledtokenizer = AutoTokenizer.from_pretrained(ledmodel_name)
|
158 |
+
self.ledmodel = BartForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
|
159 |
self.ledmodel.eval()
|
160 |
if not no_save_models:
|
161 |
self.ledmodel.save_pretrained(models_dir + "/ledmodel")
|
|
|
180 |
self.summ_model.eval()
|
181 |
self.model = Summarizer(custom_model=self.summ_model, custom_tokenizer=self.summ_tokenizer)
|
182 |
|
183 |
+
if 'led' in ledmodel_name:
|
184 |
self.ledtokenizer = LEDTokenizer.from_pretrained(ledmodel_name)
|
185 |
self.ledmodel = LEDForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
|
186 |
+
elif 't5' in ledmodel_name:
|
187 |
+
self.ledtokenizer = AutoTokenizer.from_pretrained(ledmodel_name)
|
188 |
self.ledmodel = T5ForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
|
189 |
+
elif 'bart' in ledmodel_name:
|
190 |
+
self.ledtokenizer = AutoTokenizer.from_pretrained(ledmodel_name)
|
191 |
+
self.ledmodel = BartForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
|
192 |
self.ledmodel.eval()
|
193 |
|
194 |
self.embedder = SentenceTransformer(models_dir + "/embedder")
|