Spaces:
				
			
			
	
			
			
		Build error
		
	
	
	
			
			
	
	
	
	
		
		
		Build error
		
	edits
Browse files
    	
        app.py
    CHANGED
    
    | 
         @@ -29,10 +29,18 @@ state = torch.load('fire.pth', map_location='cpu') 
     | 
|
| 29 | 
         
             
            state['net_params']['pretrained'] = None # no need for imagenet pretrained model
         
     | 
| 30 | 
         
             
            net_sfm = fire_network.init_network(**state['net_params']).to(device)
         
     | 
| 31 | 
         
             
            net_sfm.load_state_dict(state['state_dict'])
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 32 | 
         | 
| 33 | 
         
             
            state2 = torch.load('fire_imagenet.pth', map_location='cpu')
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 34 | 
         
             
            net_imagenet = fire_network.init_network(**state['net_params']).to(device)
         
     | 
| 35 | 
         
            -
            net_imagenet.load_state_dict(state2['state_dict'])
         
     | 
| 36 | 
         | 
| 37 | 
         
             
            # ---------------------------------------
         
     | 
| 38 | 
         
             
            transform = transforms.Compose([
         
     | 
| 
         | 
|
| 29 | 
         
             
            state['net_params']['pretrained'] = None # no need for imagenet pretrained model
         
     | 
| 30 | 
         
             
            net_sfm = fire_network.init_network(**state['net_params']).to(device)
         
     | 
| 31 | 
         
             
            net_sfm.load_state_dict(state['state_dict'])
         
     | 
| 32 | 
         
            +
            dim_red_params_dict = {}
         
     | 
| 33 | 
         
            +
            for name, param in net_sfm.named_parameters():
         
     | 
| 34 | 
         
            +
                if 'dim_reduction' in name:
         
     | 
| 35 | 
         
            +
                    dim_red_params_dict[name] = param
         
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         | 
| 38 | 
         
             
            state2 = torch.load('fire_imagenet.pth', map_location='cpu')
         
     | 
| 39 | 
         
            +
            state2['net_params'] = state['net_params']
         
     | 
| 40 | 
         
            +
            state2['state_dict'] += dim_red_params_dict
         
     | 
| 41 | 
         
            +
            # state2['net_params'] = 
         
     | 
| 42 | 
         
             
            net_imagenet = fire_network.init_network(**state['net_params']).to(device)
         
     | 
| 43 | 
         
            +
            net_imagenet.load_state_dict(state2['state_dict']) #, strict=False)
         
     | 
| 44 | 
         | 
| 45 | 
         
             
            # ---------------------------------------
         
     | 
| 46 | 
         
             
            transform = transforms.Compose([
         
     |