Gagan Bhatia commited on
Commit
c30637a
1 Parent(s): 83a1f94

Update model.py

Browse files
Files changed (1) hide show
  1. src/models/model.py +1 -0
src/models/model.py CHANGED
@@ -263,6 +263,7 @@ class Summarization:
263
  elif model_type == "mt5":
264
  self.tokenizer = MT5Tokenizer.from_pretrained(f"{model_name}")
265
  self.model = MT5ForConditionalGeneration.from_pretrained(
 
266
 
267
  def train(
268
  self,
 
263
  elif model_type == "mt5":
264
  self.tokenizer = MT5Tokenizer.from_pretrained(f"{model_name}")
265
  self.model = MT5ForConditionalGeneration.from_pretrained(
266
+ f"{model_name}", return_dict=True
267
 
268
  def train(
269
  self,