p-christ commited on
Commit
993071c
1 Parent(s): 1ad20ef

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +2 -2
pipeline.py CHANGED
@@ -3,6 +3,7 @@ from typing import Dict, List, Any
3
  import itertools
4
  from nltk import sent_tokenize
5
  import nltk
 
6
 
7
  class PreTrainedPipeline():
8
 
@@ -16,8 +17,7 @@ class PreTrainedPipeline():
16
  self.tokenizer = AutoTokenizer.from_pretrained(path)
17
 
18
  self.model_type="t5"
19
- # self.device = "cuda" if torch.cuda.is_available() else "cpu"
20
- self.device = "cpu"
21
 
22
 
23
  def __call__(self, inputs: str):
 
3
  import itertools
4
  from nltk import sent_tokenize
5
  import nltk
6
+ import torch
7
 
8
  class PreTrainedPipeline():
9
 
 
17
  self.tokenizer = AutoTokenizer.from_pretrained(path)
18
 
19
  self.model_type="t5"
20
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
21
 
22
 
23
  def __call__(self, inputs: str):