Sophie98 commited on
Commit
ab7b996
1 Parent(s): cfa3c46

fix errors

Browse files
Files changed (1) hide show
  1. styleTransfer.py +10 -37
styleTransfer.py CHANGED
@@ -11,8 +11,6 @@ from collections import OrderedDict
11
  import tensorflow_hub as hub
12
  import tensorflow as tf
13
 
14
- from torchvision.utils import save_image
15
-
16
  ############################################# TRANSFORMER ############################################
17
 
18
  vgg_path = 'vgg_normalised.pth'
@@ -20,16 +18,6 @@ decoder_path = 'decoder_iter_160000.pth'
20
  Trans_path = 'transformer_iter_160000.pth'
21
  embedding_path = 'embedding_iter_160000.pth'
22
 
23
- def test_transform(size, crop):
24
- transform_list = []
25
-
26
- if size != 0:
27
- transform_list.append(transforms.Resize(size))
28
- if crop:
29
- transform_list.append(transforms.CenterCrop(size))
30
- transform_list.append(transforms.ToTensor())
31
- transform = transforms.Compose(transform_list)
32
- return transform
33
  def style_transform(h,w):
34
  k = (h,w)
35
  size = int(np.max(k))
@@ -48,7 +36,6 @@ def content_transform():
48
  # Advanced options
49
  content_size=640
50
  style_size=640
51
- crop='store_true'
52
 
53
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
 
@@ -66,42 +53,28 @@ vgg.eval()
66
 
67
  new_state_dict = OrderedDict()
68
  state_dict = torch.load(decoder_path)
69
- for k, v in state_dict.items():
70
- #namekey = k[7:] # remove `module.`
71
- namekey = k
72
- new_state_dict[namekey] = v
73
- decoder.load_state_dict(new_state_dict)
74
 
75
  new_state_dict = OrderedDict()
76
  state_dict = torch.load(Trans_path)
77
- for k, v in state_dict.items():
78
- #namekey = k[7:] # remove `module.`
79
- namekey = k
80
- new_state_dict[namekey] = v
81
- Trans.load_state_dict(new_state_dict)
82
 
83
  new_state_dict = OrderedDict()
84
  state_dict = torch.load(embedding_path)
85
- for k, v in state_dict.items():
86
- #namekey = k[7:] # remove `module.`
87
- namekey = k
88
- new_state_dict[namekey] = v
89
- embedding.load_state_dict(new_state_dict)
90
 
91
  network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
92
  network.eval()
93
 
 
 
 
94
  def StyleTransformer(content_img: Image, style_img: Image):
95
 
96
  network.to(device)
97
- # content_tf = test_transform(content_size, crop)
98
- # style_tf = test_transform(style_size, crop)
99
-
100
- content_tf1 = content_transform()
101
- content = content_tf1(content_img.convert("RGB"))
102
- h,w,c=np.shape(content)
103
- style_tf1 = style_transform(h,w)
104
- style = style_tf1(style_img.convert("RGB"))
105
  style = style.to(device).unsqueeze(0)
106
  content = content.to(device).unsqueeze(0)
107
 
@@ -128,4 +101,4 @@ def StyleGAN(content_image, style_image):
128
  def create_styledSofa(sofa:Image, style:Image):
129
  #styled_sofa = StyleGAN(sofa,style)
130
  styled_sofa = StyleTransformer(sofa,style)
131
- return styled_sofa
11
  import tensorflow_hub as hub
12
  import tensorflow as tf
13
 
 
 
14
  ############################################# TRANSFORMER ############################################
15
 
16
  vgg_path = 'vgg_normalised.pth'
18
  Trans_path = 'transformer_iter_160000.pth'
19
  embedding_path = 'embedding_iter_160000.pth'
20
 
 
 
 
 
 
 
 
 
 
 
21
  def style_transform(h,w):
22
  k = (h,w)
23
  size = int(np.max(k))
36
  # Advanced options
37
  content_size=640
38
  style_size=640
 
39
 
40
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
 
53
 
54
  new_state_dict = OrderedDict()
55
  state_dict = torch.load(decoder_path)
56
+ decoder.load_state_dict(state_dict)
 
 
 
 
57
 
58
  new_state_dict = OrderedDict()
59
  state_dict = torch.load(Trans_path)
60
+ Trans.load_state_dict(state_dict)
 
 
 
 
61
 
62
  new_state_dict = OrderedDict()
63
  state_dict = torch.load(embedding_path)
64
+ embedding.load_state_dict(state_dict)
 
 
 
 
65
 
66
  network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
67
  network.eval()
68
 
69
+ content_tf = content_transform()
70
+ style_tf = style_transform(style_size,style_size)
71
+
72
  def StyleTransformer(content_img: Image, style_img: Image):
73
 
74
  network.to(device)
75
+
76
+ content = content_tf(content_img.convert("RGB"))
77
+ style = style_tf(style_img.convert("RGB"))
 
 
 
 
 
78
  style = style.to(device).unsqueeze(0)
79
  content = content.to(device).unsqueeze(0)
80
 
101
  def create_styledSofa(sofa:Image, style:Image):
102
  #styled_sofa = StyleGAN(sofa,style)
103
  styled_sofa = StyleTransformer(sofa,style)
104
+ return styled_sofa