gagan3012 commited on
Commit
d1aa7b9
1 Parent(s): fc24771

added params

Browse files
requirements.txt CHANGED
@@ -1,7 +1,7 @@
1
- numpy==1.19.2
2
- datasets==1.8.0
3
  pytorch_lightning==1.3.5
4
- transformers==4.6.0
5
  torch==1.9.0+cu111
6
  dagshub==0.1.6
7
  pandas==1.2.4
1
+ numpy==1.21.1
2
+ datasets==1.10.2
3
  pytorch_lightning==1.3.5
4
+ transformers==4.9.0
5
  torch==1.9.0+cu111
6
  dagshub==0.1.6
7
  pandas==1.2.4
src/models/model.py CHANGED
@@ -1,6 +1,4 @@
1
- import time
2
  import torch
3
- import numpy as np
4
  import pandas as pd
5
  from dagshub.pytorch_lightning import DAGsHubLogger
6
  from transformers import (
@@ -319,7 +317,7 @@ class Summarization:
319
 
320
  self.T5Model = LightningModel(
321
  tokenizer=self.tokenizer, model=self.model, output=outputdir,
322
- learning_rate=learning_rate,adam_epsilon=adam_epsilon
323
  )
324
 
325
  MLlogger = MLFlowLogger(experiment_name="Summarization",
 
1
  import torch
 
2
  import pandas as pd
3
  from dagshub.pytorch_lightning import DAGsHubLogger
4
  from transformers import (
317
 
318
  self.T5Model = LightningModel(
319
  tokenizer=self.tokenizer, model=self.model, output=outputdir,
320
+ learning_rate=learning_rate, adam_epsilon=adam_epsilon
321
  )
322
 
323
  MLlogger = MLFlowLogger(experiment_name="Summarization",
src/models/predict_model.py CHANGED
@@ -1,7 +1,7 @@
1
- from src.data.make_dataset import make_dataset
2
  from .model import Summarization
3
  import pandas as pd
4
 
 
5
  def predict_model(text):
6
  """
7
  Predict the summary of the given text.
@@ -11,8 +11,8 @@ def predict_model(text):
11
  pre_summary = model.predict(text)
12
  return pre_summary
13
 
14
-
15
  if __name__ == '__main__':
16
  text = pd.load_csv('data/processed/test.csv')['input_text'][0]
17
  pre_summary = predict_model(text)
18
- print(pre_summary)
 
1
  from .model import Summarization
2
  import pandas as pd
3
 
4
+
5
  def predict_model(text):
6
  """
7
  Predict the summary of the given text.
11
  pre_summary = model.predict(text)
12
  return pre_summary
13
 
14
+
15
  if __name__ == '__main__':
16
  text = pd.load_csv('data/processed/test.csv')['input_text'][0]
17
  pre_summary = predict_model(text)
18
+ print(pre_summary)
tox.ini CHANGED
@@ -1,3 +1,3 @@
1
  [flake8]
2
- max-line-length = 79
3
  max-complexity = 10
1
  [flake8]
2
+ max-line-length = 160
3
  max-complexity = 10