Sophie98 commited on
Commit
7bebb02
1 Parent(s): 09012f9

forgot to update a file

Browse files
Files changed (1) hide show
  1. StyleTransfer/styleTransfer.py +17 -16
StyleTransfer/styleTransfer.py CHANGED
@@ -13,7 +13,7 @@ import paddlehub as phub
13
 
14
  ############################################# TRANSFORMER ############################################
15
 
16
- def style_transform(h,w):
17
  k = (h,w)
18
  transform_list = []
19
  transform_list.append(transforms.CenterCrop((h,w)))
@@ -21,13 +21,13 @@ def style_transform(h,w):
21
  transform = transforms.Compose(transform_list)
22
  return transform
23
 
24
- def content_transform():
25
  transform_list = []
26
  transform_list.append(transforms.ToTensor())
27
  transform = transforms.Compose(transform_list)
28
  return transform
29
 
30
- def StyleTransformer(content_img: Image, style_img: Image):
31
  vgg_path = 'StyleTransfer/models/vgg_normalised.pth'
32
  decoder_path = 'StyleTransfer/models/decoder_iter_160000.pth'
33
  Trans_path = 'StyleTransfer/models/transformer_iter_160000.pth'
@@ -43,7 +43,6 @@ def StyleTransformer(content_img: Image, style_img: Image):
43
  decoder = StyTR.decoder
44
  Trans = transformer.Transformer()
45
  embedding = StyTR.PatchEmbed()
46
-
47
  decoder.eval()
48
  Trans.eval()
49
  vgg.eval()
@@ -62,7 +61,6 @@ def StyleTransformer(content_img: Image, style_img: Image):
62
 
63
  network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
64
  network.eval()
65
-
66
  content_tf = content_transform()
67
  style_tf = style_transform(style_size,style_size)
68
 
@@ -78,23 +76,20 @@ def StyleTransformer(content_img: Image, style_img: Image):
78
  output = output.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
79
  return Image.fromarray(output)
80
 
81
- ############################################## STYLE-GAN #############################################
82
-
83
 
84
- def StyleGAN(content_image, style_image):
85
  style_transfer_model = tfhub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")
86
 
87
- content_image = tf.convert_to_tensor(content_image, np.float32)[tf.newaxis, ...] / 255.
88
- style_image = tf.convert_to_tensor(style_image, np.float32)[tf.newaxis, ...] / 255.
89
  output = style_transfer_model(content_image, style_image)
90
  stylized_image = output[0]
91
  return Image.fromarray(np.uint8(stylized_image[0] * 255))
92
 
93
  ########################################### STYLE PROJECTION ##########################################
94
 
95
-
96
-
97
- def styleProjection(content_image,style_image):
98
  stylepro_artistic = phub.Module(name="stylepro_artistic")
99
  result = stylepro_artistic.style_transfer(
100
  images=[{
@@ -104,9 +99,15 @@ def styleProjection(content_image,style_image):
104
 
105
  return Image.fromarray(np.uint8(result[0]['data'])[:,:,::-1]).convert('RGB')
106
 
107
-
108
- def create_styledSofa(content_image,style_image):
109
- output = StyleTransformer(content_image,style_image)
 
 
 
 
 
 
110
  return output
111
 
112
 
13
 
14
  ############################################# TRANSFORMER ############################################
15
 
16
+ def style_transform(h:int,w:int) -> transforms.Compose:
17
  k = (h,w)
18
  transform_list = []
19
  transform_list.append(transforms.CenterCrop((h,w)))
21
  transform = transforms.Compose(transform_list)
22
  return transform
23
 
24
+ def content_transform() -> transforms.Compose:
25
  transform_list = []
26
  transform_list.append(transforms.ToTensor())
27
  transform = transforms.Compose(transform_list)
28
  return transform
29
 
30
+ def StyleTransformer(content_img: Image.Image, style_img: Image.Image) -> Image.Image:
31
  vgg_path = 'StyleTransfer/models/vgg_normalised.pth'
32
  decoder_path = 'StyleTransfer/models/decoder_iter_160000.pth'
33
  Trans_path = 'StyleTransfer/models/transformer_iter_160000.pth'
43
  decoder = StyTR.decoder
44
  Trans = transformer.Transformer()
45
  embedding = StyTR.PatchEmbed()
 
46
  decoder.eval()
47
  Trans.eval()
48
  vgg.eval()
61
 
62
  network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
63
  network.eval()
 
64
  content_tf = content_transform()
65
  style_tf = style_transform(style_size,style_size)
66
 
76
  output = output.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
77
  return Image.fromarray(output)
78
 
79
+ ############################################## STYLE-FAST #############################################
 
80
 
81
+ def StyleFAST(content_image:Image.Image, style_image:Image.Image) -> Image.Image:
82
  style_transfer_model = tfhub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")
83
 
84
+ content_image = tf.convert_to_tensor(np.array(content_image), np.float32)[tf.newaxis, ...] / 255.
85
+ style_image = tf.convert_to_tensor(np.array(style_image), np.float32)[tf.newaxis, ...] / 255.
86
  output = style_transfer_model(content_image, style_image)
87
  stylized_image = output[0]
88
  return Image.fromarray(np.uint8(stylized_image[0] * 255))
89
 
90
  ########################################### STYLE PROJECTION ##########################################
91
 
92
+ def StyleProjection(content_image:Image.Image,style_image:Image.Image) -> Image.Image:
 
 
93
  stylepro_artistic = phub.Module(name="stylepro_artistic")
94
  result = stylepro_artistic.style_transfer(
95
  images=[{
99
 
100
  return Image.fromarray(np.uint8(result[0]['data'])[:,:,::-1]).convert('RGB')
101
 
102
+ def create_styledSofa(content_image:Image.Image,style_image:Image.Image,choice:str) -> Image.Image:
103
+ if choice =="Style Transformer":
104
+ output = StyleTransformer(content_image,style_image)
105
+ elif choice =="Style FAST":
106
+ output = StyleFAST(content_image,style_image)
107
+ elif choice =="Style Projection":
108
+ output = StyleProjection(content_image,style_image)
109
+ else:
110
+ output = content_image
111
  return output
112
 
113