wondervictor committed on
Commit
48c1bdf
·
verified ·
1 Parent(s): 108d3f4

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +10 -4
model.py CHANGED
@@ -64,10 +64,14 @@ class Model:
64
  return gpt_model
65
 
66
  def load_gpt_weight(self, condition_type='edge'):
 
 
67
  gpt_ckpt = models[condition_type]
68
  model_weight = load_file(gpt_ckpt)
69
  self.gpt_model.load_state_dict(model_weight, strict=False)
70
  self.gpt_model.eval()
 
 
71
  # print("gpt model is loaded")
72
 
73
  def load_t5(self):
@@ -193,10 +197,12 @@ class Model:
193
  control_strength: float,
194
  preprocessor_name: str
195
  ) -> list[PIL.Image.Image]:
196
- self.gpt_model_edge.to('cpu')
197
  self.t5_model.model.to(self.device)
198
- self.gpt_model_depth.to(self.device)
199
- self.get_control_depth.model.to(self.device)
 
 
200
  self.vq_model.to(self.device)
201
  if isinstance(image, np.ndarray):
202
  image = Image.fromarray(image)
@@ -237,7 +243,7 @@ class Model:
237
  qzshape = [len(c_indices), 8, H // 16, W // 16]
238
  t1 = time.time()
239
  index_sample = generate(
240
- self.gpt_model_depth,
241
  c_indices,
242
  (H // 16) * (W // 16),
243
  c_emb_masks,
 
64
  return gpt_model
65
 
66
  def load_gpt_weight(self, condition_type='edge'):
67
+ torch.cuda.empty_cache()
68
+ gc.collect()
69
  gpt_ckpt = models[condition_type]
70
  model_weight = load_file(gpt_ckpt)
71
  self.gpt_model.load_state_dict(model_weight, strict=False)
72
  self.gpt_model.eval()
73
+ torch.cuda.empty_cache()
74
+ gc.collect()
75
  # print("gpt model is loaded")
76
 
77
  def load_t5(self):
 
197
  control_strength: float,
198
  preprocessor_name: str
199
  ) -> list[PIL.Image.Image]:
200
+ # self.gpt_model_edge.to('cpu')
201
  self.t5_model.model.to(self.device)
202
+ # self.gpt_model_depth.to(self.device)
203
+ self.load_gpt_weight('depth')
204
+ self.gpt_model.to('cuda').to(torch.bfloat16)
205
+ # self.get_control_depth.model.to(self.device)
206
  self.vq_model.to(self.device)
207
  if isinstance(image, np.ndarray):
208
  image = Image.fromarray(image)
 
243
  qzshape = [len(c_indices), 8, H // 16, W // 16]
244
  t1 = time.time()
245
  index_sample = generate(
246
+ self.gpt_model,
247
  c_indices,
248
  (H // 16) * (W // 16),
249
  c_emb_masks,