Spaces:
Build error
Build error
t5 led exchange
Browse files- src/Surveyor.py +13 -5
src/Surveyor.py
CHANGED
@@ -131,8 +131,12 @@ class Surveyor:
|
|
131 |
#self.summ_tokenizer.save_pretrained(models_dir + "/summ_tokenizer")
|
132 |
self.model = Summarizer(custom_model=self.summ_model, custom_tokenizer=self.summ_tokenizer)
|
133 |
|
134 |
-
|
135 |
-
|
|
|
|
|
|
|
|
|
136 |
self.ledmodel.eval()
|
137 |
if not no_save_models:
|
138 |
self.ledmodel.save_pretrained(models_dir + "/ledmodel")
|
@@ -144,7 +148,7 @@ class Surveyor:
|
|
144 |
self.embedder.save(models_dir + "/embedder")
|
145 |
else:
|
146 |
print("\nInitializing from previously saved models at" + models_dir)
|
147 |
-
self.title_tokenizer = AutoTokenizer.from_pretrained(title_model_name)
|
148 |
self.title_model = AutoModelForSeq2SeqLM.from_pretrained(models_dir + "/title_model").to(self.torch_device)
|
149 |
self.title_model.eval()
|
150 |
|
@@ -157,8 +161,12 @@ class Surveyor:
|
|
157 |
self.summ_model.eval()
|
158 |
self.model = Summarizer(custom_model=self.summ_model, custom_tokenizer=self.summ_tokenizer)
|
159 |
|
160 |
-
|
161 |
-
|
|
|
|
|
|
|
|
|
162 |
self.ledmodel.eval()
|
163 |
|
164 |
self.embedder = SentenceTransformer(models_dir + "/embedder")
|
|
|
131 |
#self.summ_tokenizer.save_pretrained(models_dir + "/summ_tokenizer")
|
132 |
self.model = Summarizer(custom_model=self.summ_model, custom_tokenizer=self.summ_tokenizer)
|
133 |
|
134 |
+
if 't5' not in ledmodel_name:
|
135 |
+
self.ledtokenizer = LEDTokenizer.from_pretrained(ledmodel_name)
|
136 |
+
self.ledmodel = LEDForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
|
137 |
+
else:
|
138 |
+
self.ledtokenizer = T5Tokenizer.from_pretrained(ledmodel_name)
|
139 |
+
self.ledmodel = T5ForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
|
140 |
self.ledmodel.eval()
|
141 |
if not no_save_models:
|
142 |
self.ledmodel.save_pretrained(models_dir + "/ledmodel")
|
|
|
148 |
self.embedder.save(models_dir + "/embedder")
|
149 |
else:
|
150 |
print("\nInitializing from previously saved models at" + models_dir)
|
151 |
+
self.title_tokenizer = AutoTokenizer.from_pretrained(title_model_name)
|
152 |
self.title_model = AutoModelForSeq2SeqLM.from_pretrained(models_dir + "/title_model").to(self.torch_device)
|
153 |
self.title_model.eval()
|
154 |
|
|
|
161 |
self.summ_model.eval()
|
162 |
self.model = Summarizer(custom_model=self.summ_model, custom_tokenizer=self.summ_tokenizer)
|
163 |
|
164 |
+
if 't5' not in ledmodel_name:
|
165 |
+
self.ledtokenizer = LEDTokenizer.from_pretrained(ledmodel_name)
|
166 |
+
self.ledmodel = LEDForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
|
167 |
+
else:
|
168 |
+
self.ledtokenizer = T5Tokenizer.from_pretrained(ledmodel_name)
|
169 |
+
self.ledmodel = T5ForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
|
170 |
self.ledmodel.eval()
|
171 |
|
172 |
self.embedder = SentenceTransformer(models_dir + "/embedder")
|