JasonTPhillipsJr commited on
Commit
80744c0
·
verified ·
1 Parent(s): 811f100

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -0
app.py CHANGED
@@ -32,7 +32,19 @@ spaBERT_model.load_state_dict(b_model.state_dict(), strict = False)
32
 
33
  pre_trained_model = torch.load(pretrained_model, map_location=torch.device('cpu'))
34
 
 
 
 
 
 
 
 
 
 
35
 
 
 
 
36
 
37
 
38
 
 
32
 
33
  pre_trained_model = torch.load(pretrained_model, map_location=torch.device('cpu'))
34
 
35
+ model_keys = spaBERT_model.state_dict()
36
+ cnt_layers = 0
37
+ for key in model_keys
38
+ if key in pre_trained_model:
39
+ model_keys[key] = pre_trained_model[key]
40
+ cnt_layers += 1
41
+ else:
42
+ print("No weight for", key)
43
+ print(cnt_layers, 'layers loaded')
44
 
45
+ spaBERT_model.load_state_dict(model_keys)
46
+ spaBERT_model.to(device)
47
+ spaBERT_model.eval()
48
 
49
 
50