gherget commited on
Commit
da681d6
1 Parent(s): e00031d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -84,9 +84,9 @@ def predict(net, inputs_val, shapes_val, hypar, device):
84
 
85
  inputs_val_v = Variable(inputs_val, requires_grad=False).to(device) # wrap inputs in Variable
86
 
87
- ds_val = net(inputs_val_v)[1] # list of 6 results
88
 
89
- pred_val = ds_val[0][0,:,:,:] # B x 1 x H x W # we want the first one which is the most accurate prediction
90
 
91
  ## recover the prediction spatial size to the orignal image size
92
  pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[0][0],shapes_val[0][1]),mode='bilinear'))
 
84
 
85
  inputs_val_v = Variable(inputs_val, requires_grad=False).to(device) # wrap inputs in Variable
86
 
87
+ ds_val = net(inputs_val_v)[0] # list of 6 results
88
 
89
+ pred_val = ds_val[1][0,:,:,:] # B x 1 x H x W # we want the first one which is the most accurate prediction
90
 
91
  ## recover the prediction spatial size to the orignal image size
92
  pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[0][0],shapes_val[0][1]),mode='bilinear'))