Michael Gira commited on
Commit
74b3160
1 Parent(s): 659007c

Change device according to environment

Browse files
Files changed (1) hide show
  1. load_model.py +1 -1
load_model.py CHANGED
@@ -4,7 +4,7 @@ import torch
4
  from transformers import GPT2Tokenizer, GPT2LMHeadModel
5
  from model import get_model
6
 
7
- device = 'cpu'
8
  models_path = 'models'
9
 
10
  tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
4
  from transformers import GPT2Tokenizer, GPT2LMHeadModel
5
  from model import get_model
6
 
7
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8
  models_path = 'models'
9
 
10
  tokenizer = GPT2Tokenizer.from_pretrained('gpt2')