YannisK commited on
Commit
22a02d6
1 Parent(s): cb6bc7f
Files changed (1) hide show
  1. app.py +2 -8
app.py CHANGED
@@ -24,7 +24,7 @@ scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
24
 
25
  device = 'cpu'
26
 
27
- # Load net
28
  state = torch.load('fire.pth', map_location='cpu')
29
  state['net_params']['pretrained'] = None # no need for imagenet pretrained model
30
  net_sfm = fire_network.init_network(**state['net_params']).to(device)
@@ -37,8 +37,7 @@ for name, param in net_sfm.named_parameters():
37
 
38
  state2 = torch.load('fire_imagenet.pth', map_location='cpu')
39
  state2['net_params'] = state['net_params']
40
- state2['state_dict'] += dim_red_params_dict
41
- # state2['net_params'] =
42
  net_imagenet = fire_network.init_network(**state['net_params']).to(device)
43
  net_imagenet.load_state_dict(state2['state_dict']) #, strict=False)
44
 
@@ -51,21 +50,16 @@ transform = transforms.Compose([
51
  # ---------------------------------------
52
 
53
  # class ImgDataset(data.Dataset):
54
-
55
  # def __init__(self, images, imsize):
56
  # self.images = images
57
  # self.imsize = imsize
58
  # self.transform = transforms.Compose([transforms.ToTensor(), \
59
  # transforms.Normalize(**dict(zip(["mean", "std"], net.runtime['mean_std'])))])
60
-
61
-
62
  # def __getitem__(self, index):
63
  # img = self.images[index]
64
  # img.thumbnail((self.imsize, self.imsize), Image.Resampling.LANCZOS)
65
  # print('after imresize:', img.size)
66
  # return self.transform(img)
67
-
68
-
69
  # def __len__(self):
70
  # return len(self.images)
71
 
 
24
 
25
  device = 'cpu'
26
 
27
+ # Load nets
28
  state = torch.load('fire.pth', map_location='cpu')
29
  state['net_params']['pretrained'] = None # no need for imagenet pretrained model
30
  net_sfm = fire_network.init_network(**state['net_params']).to(device)
 
37
 
38
  state2 = torch.load('fire_imagenet.pth', map_location='cpu')
39
  state2['net_params'] = state['net_params']
40
+ state2['state_dict'] = dict(state2['state_dict'], **dim_red_params_dict);
 
41
  net_imagenet = fire_network.init_network(**state['net_params']).to(device)
42
  net_imagenet.load_state_dict(state2['state_dict']) #, strict=False)
43
 
 
50
  # ---------------------------------------
51
 
52
  # class ImgDataset(data.Dataset):
 
53
  # def __init__(self, images, imsize):
54
  # self.images = images
55
  # self.imsize = imsize
56
  # self.transform = transforms.Compose([transforms.ToTensor(), \
57
  # transforms.Normalize(**dict(zip(["mean", "std"], net.runtime['mean_std'])))])
 
 
58
  # def __getitem__(self, index):
59
  # img = self.images[index]
60
  # img.thumbnail((self.imsize, self.imsize), Image.Resampling.LANCZOS)
61
  # print('after imresize:', img.size)
62
  # return self.transform(img)
 
 
63
  # def __len__(self):
64
  # return len(self.images)
65