p-christ commited on
Commit
1f06868
1 Parent(s): ac12f43

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +5 -1
pipeline.py CHANGED
@@ -5,7 +5,7 @@ from nltk import sent_tokenize
5
  import nltk
6
 
7
  class PreTrainedPipeline():
8
- model_type="t5"
9
  def __init__(self, path=""):
10
  # IMPLEMENT_THIS
11
  # Preload all the elements you are going to need at inference.
@@ -14,6 +14,10 @@ class PreTrainedPipeline():
14
  nltk.download('punkt')
15
  self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
16
  self.tokenizer = AutoTokenizer.from_pretrained(path)
 
 
 
 
17
 
18
 
19
  def __call__(self, inputs: str):
 
5
  import nltk
6
 
7
  class PreTrainedPipeline():
8
+
9
  def __init__(self, path=""):
10
  # IMPLEMENT_THIS
11
  # Preload all the elements you are going to need at inference.
 
14
  nltk.download('punkt')
15
  self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
16
  self.tokenizer = AutoTokenizer.from_pretrained(path)
17
+
18
+ self.model_type="t5"
19
+ # self.device = "cuda" if torch.cuda.is_available() else "cpu"
20
+ self.device = "cpu"
21
 
22
 
23
  def __call__(self, inputs: str):