wondervictor commited on
Commit
279dcd4
·
verified ·
1 Parent(s): 3cf9b05

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +4 -2
model.py CHANGED
@@ -92,7 +92,7 @@ class Model:
92
  preprocessor_name: str,
93
  ) -> list[PIL.Image.Image]:
94
  self.t5_model.model.to('cuda').to(torch.bfloat16)
95
- self.gpt_model_canny.to('cuda').to(torch.bfloat16)
96
  self.vq_model.to('cuda')
97
  if isinstance(image, np.ndarray):
98
  image = Image.fromarray(image)
@@ -147,6 +147,7 @@ class Model:
147
  top_k=top_k,
148
  top_p=top_p,
149
  sample_logits=True,
 
150
  )
151
  sampling_time = time.time() - t1
152
  print(f"Full sampling takes about {sampling_time:.2f} seconds.")
@@ -183,7 +184,7 @@ class Model:
183
  control_strength: float,
184
  preprocessor_name: str
185
  ) -> list[PIL.Image.Image]:
186
- self.gpt_model_canny.to('cpu')
187
  self.t5_model.model.to(self.device)
188
  self.gpt_model_depth.to(self.device)
189
  self.get_control_depth.model.to(self.device)
@@ -237,6 +238,7 @@ class Model:
237
  top_k=top_k,
238
  top_p=top_p,
239
  sample_logits=True,
 
240
  )
241
  sampling_time = time.time() - t1
242
  print(f"Full sampling takes about {sampling_time:.2f} seconds.")
 
92
  preprocessor_name: str,
93
  ) -> list[PIL.Image.Image]:
94
  self.t5_model.model.to('cuda').to(torch.bfloat16)
95
+ self.gpt_model_edge.to('cuda').to(torch.bfloat16)
96
  self.vq_model.to('cuda')
97
  if isinstance(image, np.ndarray):
98
  image = Image.fromarray(image)
 
147
  top_k=top_k,
148
  top_p=top_p,
149
  sample_logits=True,
150
+ control_strength=control_strength,
151
  )
152
  sampling_time = time.time() - t1
153
  print(f"Full sampling takes about {sampling_time:.2f} seconds.")
 
184
  control_strength: float,
185
  preprocessor_name: str
186
  ) -> list[PIL.Image.Image]:
187
+ self.gpt_model_edge.to('cpu')
188
  self.t5_model.model.to(self.device)
189
  self.gpt_model_depth.to(self.device)
190
  self.get_control_depth.model.to(self.device)
 
238
  top_k=top_k,
239
  top_p=top_p,
240
  sample_logits=True,
241
+ control_strength=control_strength,
242
  )
243
  sampling_time = time.time() - t1
244
  print(f"Full sampling takes about {sampling_time:.2f} seconds.")