Gagan Bhatia commited on
Commit
7c59938
1 Parent(s): 16c3afe

Update model.py

Browse files
Files changed (1) hide show
  1. src/models/model.py +13 -13
src/models/model.py CHANGED
@@ -433,19 +433,19 @@ class Summarization:
433
  self.model.save_pretrained(path)
434
 
435
  def predict(
436
- self,
437
- source_text: str,
438
- max_length: int = 512,
439
- num_return_sequences: int = 1,
440
- num_beams: int = 2,
441
- top_k: int = 50,
442
- top_p: float = 0.95,
443
- do_sample: bool = True,
444
- repetition_penalty: float = 2.5,
445
- length_penalty: float = 1.0,
446
- early_stopping: bool = True,
447
- skip_special_tokens: bool = True,
448
- clean_up_tokenization_spaces: bool = True,
449
  ):
450
  """
451
  generates prediction for T5/MT5 model
 
433
  self.model.save_pretrained(path)
434
 
435
  def predict(
436
+ self,
437
+ source_text: str,
438
+ max_length: int = 512,
439
+ num_return_sequences: int = 1,
440
+ num_beams: int = 2,
441
+ top_k: int = 50,
442
+ top_p: float = 0.95,
443
+ do_sample: bool = True,
444
+ repetition_penalty: float = 2.5,
445
+ length_penalty: float = 1.0,
446
+ early_stopping: bool = True,
447
+ skip_special_tokens: bool = True,
448
+ clean_up_tokenization_spaces: bool = True,
449
  ):
450
  """
451
  generates prediction for T5/MT5 model