100rabhsah commited on
Commit
f422b8d
·
1 Parent(s): 330f66b

app.py analyse function fixed

Browse files
Files changed (1) hide show
  1. src/app.py +8 -1
src/app.py CHANGED
@@ -43,7 +43,14 @@ def load_model():
43
  )
44
 
45
  # Load trained weights
46
- model.load_state_dict(torch.load(SAVED_MODELS_DIR / "final_model.pt", map_location=torch.device('cpu')))
 
 
 
 
 
 
 
47
  model.eval()
48
  st.session_state.model = model
49
 
 
43
  )
44
 
45
  # Load trained weights
46
+ state_dict = torch.load(SAVED_MODELS_DIR / "final_model.pt", map_location=torch.device('cpu'))
47
+
48
+ # Filter out unexpected keys
49
+ model_state_dict = model.state_dict()
50
+ filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
51
+
52
+ # Load the filtered state dict
53
+ model.load_state_dict(filtered_state_dict, strict=False)
54
  model.eval()
55
  st.session_state.model = model
56