Gagan Bhatia committed on
Commit 2679662
1 Parent(s): 4ac518a

Update model.py

Files changed (1)
  1. src/models/model.py +15 -5
src/models/model.py CHANGED
@@ -252,24 +252,34 @@ class LightningModel(LightningModule):
         no_decay = ["bias", "LayerNorm.weight"]
         optimizer_grouped_parameters = [
             {
-                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+                "params": [
+                    p
+                    for n, p in model.named_parameters()
+                    if not any(nd in n for nd in no_decay)
+                ],
                 "weight_decay": self.weight_decay,
             },
             {
-                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+                "params": [
+                    p
+                    for n, p in model.named_parameters()
+                    if any(nd in n for nd in no_decay)
+                ],
                 "weight_decay": 0.0,
             },
         ]
-        optimizer = AdamW(optimizer_grouped_parameters, lr=self.learning_rate, eps=self.adam_epsilon)
+        optimizer = AdamW(
+            optimizer_grouped_parameters, lr=self.learning_rate, eps=self.adam_epsilon
+        )
         self.opt = optimizer
         return [optimizer]
 
 
 class Summarization:
-    """ Custom Summarization class """
+    """Custom Summarization class"""
 
     def __init__(self) -> None:
-        """ initiates Summarization class """
+        """initiates Summarization class"""
         pass
 
     def from_pretrained(self, model_type="t5", model_name="t5-base") -> None:
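
For context, the hunk above only reformats (black-style line wrapping) the standard "no weight decay for biases and LayerNorm weights" parameter grouping passed to AdamW. Below is a minimal, self-contained sketch of that technique; the TinyEncoder module, the hyperparameter values, and the torch.optim.AdamW import are illustrative assumptions, not taken from this repository (which may import AdamW from transformers and uses its own model and settings).

# Minimal sketch of the weight-decay grouping shown in the diff above.
# NOTE: TinyEncoder, the hyperparameter values, and torch.optim.AdamW are
# illustrative assumptions; they are not taken from this repository.
import torch
from torch import nn
from torch.optim import AdamW


class TinyEncoder(nn.Module):
    """Toy module whose parameter names mimic transformer layers."""

    def __init__(self) -> None:
        super().__init__()
        self.dense = nn.Linear(16, 16)
        self.LayerNorm = nn.LayerNorm(16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.LayerNorm(self.dense(x))


model = TinyEncoder()
weight_decay, learning_rate, adam_epsilon = 0.01, 3e-4, 1e-8

# Parameters whose qualified names contain these substrings get no weight decay.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [
            p
            for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": weight_decay,
    },
    {
        "params": [
            p
            for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]

optimizer = AdamW(
    optimizer_grouped_parameters, lr=learning_rate, eps=adam_epsilon
)
# Only dense.weight is decayed; dense.bias, LayerNorm.weight, and LayerNorm.bias
# land in the zero-decay group.

Keeping bias and LayerNorm parameters out of weight decay is a common convention in transformer fine-tuning, since decaying those terms tends to hurt optimization rather than act as useful regularization.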