selbl commited on
Commit
8920a14
1 Parent(s): 1cbf66f

Update StreamlitModel.py

Browse files
Files changed (1) hide show
  1. StreamlitModel.py +3 -3
StreamlitModel.py CHANGED
@@ -20,8 +20,8 @@ import torch.hub
20
  profanity.load_censor_words()
21
 
22
  #It seems the streamlit space does not allow mps
23
- #device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_built() else 'cpu'
24
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
25
 
26
  class GPT2_Model(GPT2PreTrainedModel):
27
 
@@ -105,7 +105,7 @@ def TextGeneration(prompt,prof=False,parts=True):
105
  gpt_model = GPT2_Model(configuration).to(device)
106
  #gpt_model.load_state_dict(torch.load('GPT-Trained-Model-Prod.pt'))
107
  state_dict = torch.hub.load_state_dict_from_url(r'https://github.com/Selbl/LyricGeneration/raw/main/GPT-Trained-Model-Prod.pt?download=')
108
- gpt_model.load_state_dict(state_dict)
109
  #Load tokenizer
110
  tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|pad|>')
111
  #Call loader function
 
20
  profanity.load_censor_words()
21
 
22
  #It seems the streamlit space does not allow mps
23
+ device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_built() else 'cpu'
24
+ #device = 'cuda' if torch.cuda.is_available() else 'cpu'
25
 
26
  class GPT2_Model(GPT2PreTrainedModel):
27
 
 
105
  gpt_model = GPT2_Model(configuration).to(device)
106
  #gpt_model.load_state_dict(torch.load('GPT-Trained-Model-Prod.pt'))
107
  state_dict = torch.hub.load_state_dict_from_url(r'https://github.com/Selbl/LyricGeneration/raw/main/GPT-Trained-Model-Prod.pt?download=')
108
+ gpt_model.load_state_dict(state_dict,map_location=device)
109
  #Load tokenizer
110
  tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|pad|>')
111
  #Call loader function