Sophie98 commited on
Commit
cfa3c46
1 Parent(s): 21b3928

fix error again

Browse files
Files changed (1) hide show
  1. styleTransfer.py +15 -15
styleTransfer.py CHANGED
@@ -66,27 +66,27 @@ vgg.eval()
66
 
67
  new_state_dict = OrderedDict()
68
  state_dict = torch.load(decoder_path)
69
- # for k, v in state_dict.items():
70
- # #namekey = k[7:] # remove `module.`
71
- # namekey = k
72
- # new_state_dict[namekey] = v
73
- decoder.load_state_dict(state_dict)
74
 
75
  new_state_dict = OrderedDict()
76
  state_dict = torch.load(Trans_path)
77
- # for k, v in state_dict.items():
78
- # #namekey = k[7:] # remove `module.`
79
- # namekey = k
80
- # new_state_dict[namekey] = v
81
- Trans.load_state_dict(state_dict)
82
 
83
  new_state_dict = OrderedDict()
84
  state_dict = torch.load(embedding_path)
85
- # for k, v in state_dict.items():
86
- # #namekey = k[7:] # remove `module.`
87
- # namekey = k
88
- # new_state_dict[namekey] = v
89
- embedding.load_state_dict(state_dict)
90
 
91
  network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
92
  network.eval()
 
66
 
67
  new_state_dict = OrderedDict()
68
  state_dict = torch.load(decoder_path)
69
+ for k, v in state_dict.items():
70
+ #namekey = k[7:] # remove `module.`
71
+ namekey = k
72
+ new_state_dict[namekey] = v
73
+ decoder.load_state_dict(new_state_dict)
74
 
75
  new_state_dict = OrderedDict()
76
  state_dict = torch.load(Trans_path)
77
+ for k, v in state_dict.items():
78
+ #namekey = k[7:] # remove `module.`
79
+ namekey = k
80
+ new_state_dict[namekey] = v
81
+ Trans.load_state_dict(new_state_dict)
82
 
83
  new_state_dict = OrderedDict()
84
  state_dict = torch.load(embedding_path)
85
+ for k, v in state_dict.items():
86
+ #namekey = k[7:] # remove `module.`
87
+ namekey = k
88
+ new_state_dict[namekey] = v
89
+ embedding.load_state_dict(new_state_dict)
90
 
91
  network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
92
  network.eval()