mawairon commited on
Commit
3a67180
1 Parent(s): 5a0bdbe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -34,7 +34,7 @@ def load_model(model_name: str):
34
  metadata_features = 0
35
  N_UNIQUE_CLASSES = 38
36
 
37
- if model_name == 'gena-bert':
38
  base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
39
  tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)
40
 
@@ -80,7 +80,7 @@ def load_model(model_name: str):
80
  model_seq,
81
  new_head
82
  )
83
- weights = torch.load('CNN_1stGEAC_m2_best.pth')
84
  model.load_state_dict(weights)
85
  return model, None
86
 
@@ -112,7 +112,7 @@ def load_model(model_name: str):
112
  new_head
113
  )
114
 
115
- weights = torch.load('NOO_CNN_1stGEAC_m4_16kcw_best.pth')
116
  joined_model.load_state_dict(weights)
117
 
118
  return joined_model, None
 
34
  metadata_features = 0
35
  N_UNIQUE_CLASSES = 38
36
 
37
+ if model_name == 'GENA-Bert':
38
  base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
39
  tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)
40
 
 
80
  model_seq,
81
  new_head
82
  )
83
+ weights = torch.load('CNN_1stGEAC_m2_best.pth',map_location=torch.device('cpu'))
84
  model.load_state_dict(weights)
85
  return model, None
86
 
 
112
  new_head
113
  )
114
 
115
+ weights = torch.load('NOO_CNN_1stGEAC_m4_16kcw_best.pth',map_location=torch.device('cpu'))
116
  joined_model.load_state_dict(weights)
117
 
118
  return joined_model, None