Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -32,7 +32,19 @@ spaBERT_model.load_state_dict(b_model.state_dict(), strict = False)
|
|
32 |
|
33 |
pre_trained_model = torch.load(pretrained_model, map_location=torch.device('cpu'))
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
|
|
|
|
|
|
36 |
|
37 |
|
38 |
|
|
|
32 |
|
33 |
pre_trained_model = torch.load(pretrained_model, map_location=torch.device('cpu'))
|
34 |
|
35 |
+
model_keys = spaBERT_model.state_dict()
|
36 |
+
cnt_layers = 0
|
37 |
+
for key in model_keys
|
38 |
+
if key in pre_trained_model:
|
39 |
+
model_keys[key] = pre_trained_model[key]
|
40 |
+
cnt_layers += 1
|
41 |
+
else:
|
42 |
+
print("No weight for", key)
|
43 |
+
print(cnt_layers, 'layers loaded')
|
44 |
|
45 |
+
spaBERT_model.load_state_dict(model_keys)
|
46 |
+
spaBERT_model.to(device)
|
47 |
+
spaBERT_model.eval()
|
48 |
|
49 |
|
50 |
|