PKUWilliamYang committed on
Commit
b5f45ec
1 Parent(s): 4a0d568

Update dualstylegan.py

Browse files
Files changed (1) hide show
  1. dualstylegan.py +46 -9
dualstylegan.py CHANGED
@@ -34,8 +34,9 @@ class Model:
34
  def __init__(self, device: torch.device | str):
35
  self.device = torch.device(device)
36
  self.landmark_model = self._create_dlib_landmark_model()
37
- self.encoder = self._load_encoder()
38
  self.transform = self._create_transform()
 
39
 
40
  self.style_types = [
41
  'cartoon',
@@ -76,7 +77,20 @@ class Model:
76
  model = pSp(opts)
77
  model.to(self.device)
78
  model.eval()
79
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  @staticmethod
82
  def _create_transform() -> Callable:
@@ -111,6 +125,9 @@ class Model:
111
 
112
  def detect_and_align_face(self, image) -> np.ndarray:
113
  image = align_face(filepath=image.name, predictor=self.landmark_model)
 
 
 
114
  return image
115
 
116
  @staticmethod
@@ -123,14 +140,22 @@ class Model:
123
 
124
  @torch.inference_mode()
125
  def reconstruct_face(self,
126
- image: np.ndarray) -> tuple[np.ndarray, torch.Tensor]:
 
 
 
 
 
 
 
 
127
  image = PIL.Image.fromarray(image)
128
  input_data = self.transform(image).unsqueeze(0).to(self.device)
129
- img_rec, instyle = self.encoder(input_data,
130
  randomize_noise=False,
131
  return_latents=True,
132
- z_plus_latent=True,
133
- return_z_plus_latent=True,
134
  resize=False)
135
  img_rec = torch.clamp(img_rec.detach(), -1, 1)
136
  img_rec = self.postprocess(img_rec[0])
@@ -140,6 +165,15 @@ class Model:
140
  def generate(self, style_type: str, style_id: int, structure_weight: float,
141
  color_weight: float, structure_only: bool,
142
  instyle: torch.Tensor) -> np.ndarray:
 
 
 
 
 
 
 
 
 
143
  generator = self.generator_dict[style_type]
144
  exstyles = self.exstyle_dict[style_type]
145
 
@@ -147,15 +181,18 @@ class Model:
147
  stylename = list(exstyles.keys())[style_id]
148
 
149
  latent = torch.tensor(exstyles[stylename]).to(self.device)
150
- if structure_only:
151
  latent[0, 7:18] = instyle[0, 7:18]
152
  exstyle = generator.generator.style(
153
  latent.reshape(latent.shape[0] * latent.shape[1],
154
  latent.shape[2])).reshape(latent.shape)
 
 
155
 
156
  img_gen, _ = generator([instyle],
157
  exstyle,
158
- z_plus_latent=True,
 
159
  truncation=0.7,
160
  truncation_latent=0,
161
  use_res=True,
@@ -163,4 +200,4 @@ class Model:
163
  [color_weight] * 11)
164
  img_gen = torch.clamp(img_gen.detach(), -1, 1)
165
  img_gen = self.postprocess(img_gen[0])
166
- return img_gen
 
34
  def __init__(self, device: torch.device | str):
35
  self.device = torch.device(device)
36
  self.landmark_model = self._create_dlib_landmark_model()
37
+ self.encoder_dict = self._load_encoder()
38
  self.transform = self._create_transform()
39
+ self.encoder_type = 'z+'
40
 
41
  self.style_types = [
42
  'cartoon',
 
77
  model = pSp(opts)
78
  model.to(self.device)
79
  model.eval()
80
+
81
+ ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO,
82
+ 'models/encoder_wplus.pt')
83
+ ckpt = torch.load(ckpt_path, map_location='cpu')
84
+ opts = ckpt['opts']
85
+ opts['device'] = self.device.type
86
+ opts['checkpoint_path'] = ckpt_path
87
+ opts['output_size'] = 1024
88
+ opts = argparse.Namespace(**opts)
89
+ model2 = pSp(opts)
90
+ model2.to(self.device)
91
+ model2.eval()
92
+
93
+ return {'z+': model, 'w+': model2}
94
 
95
  @staticmethod
96
  def _create_transform() -> Callable:
 
125
 
126
  def detect_and_align_face(self, image) -> np.ndarray:
127
  image = align_face(filepath=image.name, predictor=self.landmark_model)
128
+ x, y = np.random.randint(255), np.random.randint(255)
129
+ r, g, b = image.getpixel((x, y))
130
+ image.putpixel((x, y), (r, g+1, b)) # trick to make sure run reconstruct_face() once any input setting changes
131
  return image
132
 
133
  @staticmethod
 
140
 
141
  @torch.inference_mode()
142
  def reconstruct_face(self,
143
+ image: np.ndarray, encoder_type: str) -> tuple[np.ndarray, torch.Tensor]:
144
+ if encoder_type == 'Z+ encoder (better stylization)':
145
+ self.encoder_type = 'z+'
146
+ z_plus_latent = True
147
+ return_z_plus_latent = True
148
+ else:
149
+ self.encoder_type = 'w+'
150
+ z_plus_latent = False
151
+ return_z_plus_latent = False
152
  image = PIL.Image.fromarray(image)
153
  input_data = self.transform(image).unsqueeze(0).to(self.device)
154
+ img_rec, instyle = self.encoder_dict[self.encoder_type](input_data,
155
  randomize_noise=False,
156
  return_latents=True,
157
+ z_plus_latent=z_plus_latent,
158
+ return_z_plus_latent=return_z_plus_latent,
159
  resize=False)
160
  img_rec = torch.clamp(img_rec.detach(), -1, 1)
161
  img_rec = self.postprocess(img_rec[0])
 
165
  def generate(self, style_type: str, style_id: int, structure_weight: float,
166
  color_weight: float, structure_only: bool,
167
  instyle: torch.Tensor) -> np.ndarray:
168
+
169
+
170
+ if self.encoder_type == 'z+':
171
+ z_plus_latent = True
172
+ input_is_latent = False
173
+ else:
174
+ z_plus_latent = False
175
+ input_is_latent = True
176
+
177
  generator = self.generator_dict[style_type]
178
  exstyles = self.exstyle_dict[style_type]
179
 
 
181
  stylename = list(exstyles.keys())[style_id]
182
 
183
  latent = torch.tensor(exstyles[stylename]).to(self.device)
184
+ if structure_only and self.encoder_type == 'z+':
185
  latent[0, 7:18] = instyle[0, 7:18]
186
  exstyle = generator.generator.style(
187
  latent.reshape(latent.shape[0] * latent.shape[1],
188
  latent.shape[2])).reshape(latent.shape)
189
+ if structure_only and self.encoder_type == 'w+':
190
+ exstyle[:,7:18] = instyle[:,7:18]
191
 
192
  img_gen, _ = generator([instyle],
193
  exstyle,
194
+ input_is_latent=input_is_latent,
195
+ z_plus_latent=z_plus_latent,
196
  truncation=0.7,
197
  truncation_latent=0,
198
  use_res=True,
 
200
  [color_weight] * 11)
201
  img_gen = torch.clamp(img_gen.detach(), -1, 1)
202
  img_gen = self.postprocess(img_gen[0])
203
+ return img_gen