gagan3012 committed on
Commit
c015c4c
1 Parent(s): 322ebac
.gitignore CHANGED
@@ -88,3 +88,6 @@ coverage.xml
88
 
89
  # Mypy cache
90
  .mypy_cache/
 
 
 
 
88
 
89
  # Mypy cache
90
  .mypy_cache/
91
+
92
+ .idea
93
+ .vscode
src/data/make_dataset.py CHANGED
@@ -9,3 +9,6 @@ def make_dataset(dataset='cnn_dailymail', split='train', version="3.0.0"):
9
  df['input_text'] = dataset['concepts']
10
  df['output_text'] = dataset['target']
11
  return df
 
 
 
 
9
  df['input_text'] = dataset['concepts']
10
  df['output_text'] = dataset['target']
11
  return df
12
+
13
+ if __name__ == '__main__':
14
+ make_dataset(dataset='cnn_dailymail', split='train', version="3.0.0")
src/models/model.py CHANGED
@@ -340,7 +340,7 @@ class Summarization:
340
  trainer.fit(self.T5Model, self.data_module)
341
 
342
  def load_model(
343
- self, model_dir: str = "models", use_gpu: bool = False
344
  ):
345
  """
346
  loads a checkpoint for inferencing/prediction
 
340
  trainer.fit(self.T5Model, self.data_module)
341
 
342
  def load_model(
343
+ self, model_dir: str = "../../models", use_gpu: bool = False
344
  ):
345
  """
346
  loads a checkpoint for inferencing/prediction
src/models/predict_model.py CHANGED
@@ -1,2 +1,11 @@
1
  from .model import Summarization
2
 
 
 
 
 
 
 
 
 
 
 
1
  from .model import Summarization
2
 
3
+ def predict_model(text):
4
+ """
5
+ Predict the summary of the given text.
6
+ """
7
+ model = Summarization()
8
+ model.load_model()
9
+ pre_summary = model.predict(text)
10
+ return pre_summary
11
+
src/models/train_model.py CHANGED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .model import Summarization
2
+ from data.make_dataset import make_dataset
3
+
4
+ def train_model():
5
+ """
6
+ Train the model
7
+ """
8
+ # Load the data
9
+ train_df = make_dataset(split = 'train')
10
+ eval_df = make_dataset(split = 'test')
11
+
12
+ model = Summarization()
13
+ model.from_pretrained('t5-base')
14
+ model.train(train_df=train_df, eval_df=eval_df, batch_size=4, max_epochs=3, use_gpu=True)
15
+ model.save_model()