Spaces:
Build error
Build error
edits
Browse files
app.py
CHANGED
@@ -24,7 +24,7 @@ scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
|
|
24 |
|
25 |
device = 'cpu'
|
26 |
|
27 |
-
# Load
|
28 |
state = torch.load('fire.pth', map_location='cpu')
|
29 |
state['net_params']['pretrained'] = None # no need for imagenet pretrained model
|
30 |
net_sfm = fire_network.init_network(**state['net_params']).to(device)
|
@@ -37,8 +37,7 @@ for name, param in net_sfm.named_parameters():
|
|
37 |
|
38 |
state2 = torch.load('fire_imagenet.pth', map_location='cpu')
|
39 |
state2['net_params'] = state['net_params']
|
40 |
-
state2['state_dict']
|
41 |
-
# state2['net_params'] =
|
42 |
net_imagenet = fire_network.init_network(**state['net_params']).to(device)
|
43 |
net_imagenet.load_state_dict(state2['state_dict']) #, strict=False)
|
44 |
|
@@ -51,21 +50,16 @@ transform = transforms.Compose([
|
|
51 |
# ---------------------------------------
|
52 |
|
53 |
# class ImgDataset(data.Dataset):
|
54 |
-
|
55 |
# def __init__(self, images, imsize):
|
56 |
# self.images = images
|
57 |
# self.imsize = imsize
|
58 |
# self.transform = transforms.Compose([transforms.ToTensor(), \
|
59 |
# transforms.Normalize(**dict(zip(["mean", "std"], net.runtime['mean_std'])))])
|
60 |
-
|
61 |
-
|
62 |
# def __getitem__(self, index):
|
63 |
# img = self.images[index]
|
64 |
# img.thumbnail((self.imsize, self.imsize), Image.Resampling.LANCZOS)
|
65 |
# print('after imresize:', img.size)
|
66 |
# return self.transform(img)
|
67 |
-
|
68 |
-
|
69 |
# def __len__(self):
|
70 |
# return len(self.images)
|
71 |
|
|
|
24 |
|
25 |
device = 'cpu'
|
26 |
|
27 |
+
# Load nets
|
28 |
state = torch.load('fire.pth', map_location='cpu')
|
29 |
state['net_params']['pretrained'] = None # no need for imagenet pretrained model
|
30 |
net_sfm = fire_network.init_network(**state['net_params']).to(device)
|
|
|
37 |
|
38 |
state2 = torch.load('fire_imagenet.pth', map_location='cpu')
|
39 |
state2['net_params'] = state['net_params']
|
40 |
+
state2['state_dict'] = dict(state2['state_dict'], **dim_red_params_dict);
|
|
|
41 |
net_imagenet = fire_network.init_network(**state['net_params']).to(device)
|
42 |
net_imagenet.load_state_dict(state2['state_dict']) #, strict=False)
|
43 |
|
|
|
50 |
# ---------------------------------------
|
51 |
|
52 |
# class ImgDataset(data.Dataset):
|
|
|
53 |
# def __init__(self, images, imsize):
|
54 |
# self.images = images
|
55 |
# self.imsize = imsize
|
56 |
# self.transform = transforms.Compose([transforms.ToTensor(), \
|
57 |
# transforms.Normalize(**dict(zip(["mean", "std"], net.runtime['mean_std'])))])
|
|
|
|
|
58 |
# def __getitem__(self, index):
|
59 |
# img = self.images[index]
|
60 |
# img.thumbnail((self.imsize, self.imsize), Image.Resampling.LANCZOS)
|
61 |
# print('after imresize:', img.size)
|
62 |
# return self.transform(img)
|
|
|
|
|
63 |
# def __len__(self):
|
64 |
# return len(self.images)
|
65 |
|