Spaces:
Runtime error
Runtime error
Updates
Browse files- .gitignore +3 -0
- src/data/make_dataset.py +3 -0
- src/models/model.py +1 -1
- src/models/predict_model.py +9 -0
- src/models/train_model.py +15 -0
.gitignore
CHANGED
@@ -88,3 +88,6 @@ coverage.xml
|
|
88 |
|
89 |
# Mypy cache
|
90 |
.mypy_cache/
|
|
|
|
|
|
|
|
88 |
|
89 |
# Mypy cache
|
90 |
.mypy_cache/
|
91 |
+
|
92 |
+
.idea
|
93 |
+
.vscode
|
src/data/make_dataset.py
CHANGED
@@ -9,3 +9,6 @@ def make_dataset(dataset='cnn_dailymail', split='train', version="3.0.0"):
|
|
9 |
df['input_text'] = dataset['concepts']
|
10 |
df['output_text'] = dataset['target']
|
11 |
return df
|
|
|
|
|
|
|
|
9 |
df['input_text'] = dataset['concepts']
|
10 |
df['output_text'] = dataset['target']
|
11 |
return df
|
12 |
+
|
13 |
+
if __name__ == '__main__':
|
14 |
+
make_dataset(dataset='cnn_dailymail', split='train', version="3.0.0")
|
src/models/model.py
CHANGED
@@ -340,7 +340,7 @@ class Summarization:
|
|
340 |
trainer.fit(self.T5Model, self.data_module)
|
341 |
|
342 |
def load_model(
|
343 |
-
self, model_dir: str = "models", use_gpu: bool = False
|
344 |
):
|
345 |
"""
|
346 |
loads a checkpoint for inferencing/prediction
|
|
|
340 |
trainer.fit(self.T5Model, self.data_module)
|
341 |
|
342 |
def load_model(
|
343 |
+
self, model_dir: str = "../../models", use_gpu: bool = False
|
344 |
):
|
345 |
"""
|
346 |
loads a checkpoint for inferencing/prediction
|
src/models/predict_model.py
CHANGED
@@ -1,2 +1,11 @@
|
|
1 |
from .model import Summarization
|
2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from .model import Summarization
|
2 |
|
3 |
+
def predict_model(text):
|
4 |
+
"""
|
5 |
+
Predict the summary of the given text.
|
6 |
+
"""
|
7 |
+
model = Summarization()
|
8 |
+
model.load_model()
|
9 |
+
pre_summary = model.predict(text)
|
10 |
+
return pre_summary
|
11 |
+
|
src/models/train_model.py
CHANGED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .model import Summarization
|
2 |
+
from data.make_dataset import make_dataset
|
3 |
+
|
4 |
+
def train_model():
|
5 |
+
"""
|
6 |
+
Train the model
|
7 |
+
"""
|
8 |
+
# Load the data
|
9 |
+
train_df = make_dataset(split = 'train')
|
10 |
+
eval_df = make_dataset(split = 'test')
|
11 |
+
|
12 |
+
model = Summarization()
|
13 |
+
model.from_pretrained('t5-base')
|
14 |
+
model.train(train_df=train_df, eval_df=eval_df, batch_size=4, max_epochs=3, use_gpu=True)
|
15 |
+
model.save_model()
|