|  | from collections import OrderedDict | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def normalize_activation(x, eps=1e-10): | 
					
						
						|  | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) | 
					
						
						|  | return x / (norm_factor + eps) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): | 
					
						
						|  |  | 
					
						
						|  | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ | 
					
						
						|  | + f'master/lpips/weights/v{version}/{net_type}.pth' | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | old_state_dict = torch.hub.load_state_dict_from_url( | 
					
						
						|  | url, progress=True, | 
					
						
						|  | map_location=None if torch.cuda.is_available() else torch.device('cpu') | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | new_state_dict = OrderedDict() | 
					
						
						|  | for key, val in old_state_dict.items(): | 
					
						
						|  | new_key = key | 
					
						
						|  | new_key = new_key.replace('lin', '') | 
					
						
						|  | new_key = new_key.replace('model.', '') | 
					
						
						|  | new_state_dict[new_key] = val | 
					
						
						|  |  | 
					
						
						|  | return new_state_dict | 
					
						
						|  |  |