Sophie98 commited on
Commit
587b848
1 Parent(s): 709b74f

fix error with style projection

Browse files
Files changed (1) hide show
  1. StyleTransfer/styleTransfer.py +4 -3
StyleTransfer/styleTransfer.py CHANGED
@@ -9,7 +9,7 @@ from collections import OrderedDict
9
  import tensorflow_hub as tfhub
10
  import tensorflow as tf
11
  import paddlehub as phub
12
-
13
 
14
  ############################################# TRANSFORMER ############################################
15
 
@@ -87,14 +87,15 @@ def StyleFAST(content_image:Image.Image, style_image:Image.Image) -> Image.Image
87
  return Image.fromarray(np.uint8(stylized_image[0] * 255))
88
 
89
  ########################################### STYLE PROJECTION ##########################################
 
90
  stylepro_artistic = phub.Module(name="stylepro_artistic")
91
  def StyleProjection(content_image:Image.Image,style_image:Image.Image) -> Image.Image:
92
  print('line92')
93
  result = stylepro_artistic.style_transfer(
94
  images=[{
95
  'content': np.array(content_image.convert('RGB') )[:, :, ::-1],
96
- 'styles': [np.array(style_image.convert('RGB') )[:, :, ::-1]]
97
- }])
98
  print('line97')
99
  return Image.fromarray(np.uint8(result[0]['data'])[:,:,::-1]).convert('RGB')
100
 
 
9
  import tensorflow_hub as tfhub
10
  import tensorflow as tf
11
  import paddlehub as phub
12
+ import os
13
 
14
  ############################################# TRANSFORMER ############################################
15
 
 
87
  return Image.fromarray(np.uint8(stylized_image[0] * 255))
88
 
89
  ########################################### STYLE PROJECTION ##########################################
90
+ os.system("phub install stylepro_artistic==1.0.1")
91
  stylepro_artistic = phub.Module(name="stylepro_artistic")
92
  def StyleProjection(content_image:Image.Image,style_image:Image.Image) -> Image.Image:
93
  print('line92')
94
  result = stylepro_artistic.style_transfer(
95
  images=[{
96
  'content': np.array(content_image.convert('RGB') )[:, :, ::-1],
97
+ 'styles': [np.array(style_image.convert('RGB') )[:, :, ::-1]]}],
98
+ alpha=0.8)
99
  print('line97')
100
  return Image.fromarray(np.uint8(result[0]['data'])[:,:,::-1]).convert('RGB')
101