m41w4r3.exe commited on
Commit
18f41a5
1 Parent(s): 1abfe53

add cuda if exists

Browse files
Files changed (1) hide show
  1. load.py +6 -3
load.py CHANGED
@@ -1,6 +1,7 @@
1
  from transformers import GPT2LMHeadModel
2
  from transformers import PreTrainedTokenizerFast
3
  import os
 
4
 
5
 
6
  class LoadModel:
@@ -31,6 +32,8 @@ class LoadModel:
31
  self.path = path
32
  self.device = device
33
  self.revision = revision
 
 
34
 
35
  def load_model_and_tokenizer(self):
36
  model = self.load_model()
@@ -40,11 +43,11 @@ class LoadModel:
40
 
41
  def load_model(self):
42
  if self.revision is None:
43
- model = GPT2LMHeadModel.from_pretrained(self.path, device_map="auto")
44
  else:
45
  model = GPT2LMHeadModel.from_pretrained(
46
- self.path, revision=self.revision, device_map="auto"
47
- )
48
 
49
  return model
50
 
 
1
  from transformers import GPT2LMHeadModel
2
  from transformers import PreTrainedTokenizerFast
3
  import os
4
+ import torch
5
 
6
 
7
  class LoadModel:
 
32
  self.path = path
33
  self.device = device
34
  self.revision = revision
35
+ if torch.cuda.is_available():
36
+ self.device = "cuda"
37
 
38
  def load_model_and_tokenizer(self):
39
  model = self.load_model()
 
43
 
44
  def load_model(self):
45
  if self.revision is None:
46
+ model = GPT2LMHeadModel.from_pretrained(self.path).to(self.device)
47
  else:
48
  model = GPT2LMHeadModel.from_pretrained(
49
+ self.path, revision=self.revision
50
+ ).to(self.device)
51
 
52
  return model
53