added device to QGHandler
Browse files- qg_pipeline.py +1 -1
qg_pipeline.py
CHANGED
@@ -63,7 +63,7 @@ class QGHandler:
|
|
63 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(model)
|
64 |
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
|
65 |
self.device = torch.device('gpu' if torch.cuda.is_available() else 'cpu')
|
66 |
-
self.model.to(device)
|
67 |
|
68 |
def __call__(self, answers, context):
|
69 |
tokenized_inputs = self.preprocess(answers, context)
|
|
|
63 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(model)
|
64 |
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
|
65 |
self.device = torch.device('gpu' if torch.cuda.is_available() else 'cpu')
|
66 |
+
self.model.to(self.device)
|
67 |
|
68 |
def __call__(self, answers, context):
|
69 |
tokenized_inputs = self.preprocess(answers, context)
|