Sophie98 commited on
Commit
21b3928
1 Parent(s): a37eb28

fix error I think

Browse files
Files changed (1) hide show
  1. styleTransfer.py +24 -23
styleTransfer.py CHANGED
@@ -66,41 +66,42 @@ vgg.eval()
66
 
67
  new_state_dict = OrderedDict()
68
  state_dict = torch.load(decoder_path)
69
- for k, v in state_dict.items():
70
- #namekey = k[7:] # remove `module.`
71
- namekey = k
72
- new_state_dict[namekey] = v
73
- decoder.load_state_dict(new_state_dict)
74
 
75
  new_state_dict = OrderedDict()
76
  state_dict = torch.load(Trans_path)
77
- for k, v in state_dict.items():
78
- #namekey = k[7:] # remove `module.`
79
- namekey = k
80
- new_state_dict[namekey] = v
81
- Trans.load_state_dict(new_state_dict)
82
 
83
  new_state_dict = OrderedDict()
84
  state_dict = torch.load(embedding_path)
85
- for k, v in state_dict.items():
86
- #namekey = k[7:] # remove `module.`
87
- namekey = k
88
- new_state_dict[namekey] = v
89
- embedding.load_state_dict(new_state_dict)
90
 
91
  network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
92
  network.eval()
93
- network.to(device)
94
-
95
- content_tf = test_transform(content_size, crop)
96
- style_tf = test_transform(style_size, crop)
97
 
98
  def StyleTransformer(content_img: Image, style_img: Image):
 
 
 
 
 
99
  content_tf1 = content_transform()
100
- content = content_tf(content_img.convert("RGB"))
101
  h,w,c=np.shape(content)
102
  style_tf1 = style_transform(h,w)
103
- style = style_tf(style_img.convert("RGB"))
104
  style = style.to(device).unsqueeze(0)
105
  content = content.to(device).unsqueeze(0)
106
 
@@ -125,6 +126,6 @@ def StyleGAN(content_image, style_image):
125
 
126
 
127
  def create_styledSofa(sofa:Image, style:Image):
128
- styled_sofa = StyleGAN(sofa,style)
129
- #styled_sofa = StyleTransformer(sofa,style)
130
  return styled_sofa
66
 
67
  new_state_dict = OrderedDict()
68
  state_dict = torch.load(decoder_path)
69
+ # for k, v in state_dict.items():
70
+ # #namekey = k[7:] # remove `module.`
71
+ # namekey = k
72
+ # new_state_dict[namekey] = v
73
+ decoder.load_state_dict(state_dict)
74
 
75
  new_state_dict = OrderedDict()
76
  state_dict = torch.load(Trans_path)
77
+ # for k, v in state_dict.items():
78
+ # #namekey = k[7:] # remove `module.`
79
+ # namekey = k
80
+ # new_state_dict[namekey] = v
81
+ Trans.load_state_dict(state_dict)
82
 
83
  new_state_dict = OrderedDict()
84
  state_dict = torch.load(embedding_path)
85
+ # for k, v in state_dict.items():
86
+ # #namekey = k[7:] # remove `module.`
87
+ # namekey = k
88
+ # new_state_dict[namekey] = v
89
+ embedding.load_state_dict(state_dict)
90
 
91
  network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
92
  network.eval()
 
 
 
 
93
 
94
  def StyleTransformer(content_img: Image, style_img: Image):
95
+
96
+ network.to(device)
97
+ # content_tf = test_transform(content_size, crop)
98
+ # style_tf = test_transform(style_size, crop)
99
+
100
  content_tf1 = content_transform()
101
+ content = content_tf1(content_img.convert("RGB"))
102
  h,w,c=np.shape(content)
103
  style_tf1 = style_transform(h,w)
104
+ style = style_tf1(style_img.convert("RGB"))
105
  style = style.to(device).unsqueeze(0)
106
  content = content.to(device).unsqueeze(0)
107
 
126
 
127
 
128
  def create_styledSofa(sofa:Image, style:Image):
129
+ #styled_sofa = StyleGAN(sofa,style)
130
+ styled_sofa = StyleTransformer(sofa,style)
131
  return styled_sofa