Spaces:
Runtime error
Runtime error
Gagan Bhatia
commited on
Commit
•
7c59938
1
Parent(s):
16c3afe
Update model.py
Browse files- src/models/model.py +13 -13
src/models/model.py
CHANGED
@@ -433,19 +433,19 @@ class Summarization:
|
|
433 |
self.model.save_pretrained(path)
|
434 |
|
435 |
def predict(
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
):
|
450 |
"""
|
451 |
generates prediction for T5/MT5 model
|
|
|
433 |
self.model.save_pretrained(path)
|
434 |
|
435 |
def predict(
|
436 |
+
self,
|
437 |
+
source_text: str,
|
438 |
+
max_length: int = 512,
|
439 |
+
num_return_sequences: int = 1,
|
440 |
+
num_beams: int = 2,
|
441 |
+
top_k: int = 50,
|
442 |
+
top_p: float = 0.95,
|
443 |
+
do_sample: bool = True,
|
444 |
+
repetition_penalty: float = 2.5,
|
445 |
+
length_penalty: float = 1.0,
|
446 |
+
early_stopping: bool = True,
|
447 |
+
skip_special_tokens: bool = True,
|
448 |
+
clean_up_tokenization_spaces: bool = True,
|
449 |
):
|
450 |
"""
|
451 |
generates prediction for T5/MT5 model
|