jamino30 commited on
Commit
289faff
1 Parent(s): 2636ace

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. app.py +1 -1
  2. u2net/evaluate.py +1 -1
  3. u2net/inference.py +1 -2
app.py CHANGED
@@ -22,7 +22,7 @@ print('DEVICE:', device)
22
  if device == 'cuda': print('CUDA DEVICE:', torch.cuda.get_device_name())
23
 
24
  def load_model_without_module(model, model_path):
25
- state_dict = torch.load(model_path, map_location=device, weights_only=True)
26
 
27
  new_state_dict = {}
28
  for k, v in state_dict.items():
 
22
  if device == 'cuda': print('CUDA DEVICE:', torch.cuda.get_device_name())
23
 
24
  def load_model_without_module(model, model_path):
25
+ state_dict = torch.load(model_path, map_location=device, weights_only=False)
26
 
27
  new_state_dict = {}
28
  for k, v in state_dict.items():
u2net/evaluate.py CHANGED
@@ -11,7 +11,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
  print('Device:', device)
12
 
13
  def load_model(model, model_path):
14
- state_dict = torch.load(model_path, map_location=device, weights_only=True)
15
  model.load_state_dict(state_dict)
16
  model.eval()
17
 
 
11
  print('Device:', device)
12
 
13
  def load_model(model, model_path):
14
+ state_dict = torch.load(model_path, map_location=device, weights_only=False)
15
  model.load_state_dict(state_dict)
16
  model.eval()
17
 
u2net/inference.py CHANGED
@@ -51,8 +51,7 @@ if __name__ == '__main__':
51
  # ---
52
  model = U2Net().to(device)
53
  model = nn.DataParallel(model)
54
- state_dict = torch.load(model_path, map_location=device, weights_only=True)
55
- model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
56
  model.eval()
57
 
58
  mask = run_inference(model, image_path, threshold=None)
 
51
  # ---
52
  model = U2Net().to(device)
53
  model = nn.DataParallel(model)
54
+ model.load_state_dict(torch.load(model_path, map_location=device, weights_only=False))
 
55
  model.eval()
56
 
57
  mask = run_inference(model, image_path, threshold=None)