TiankaiHang committed
Commit 6c26e0d
1 Parent(s): 29cd0de
Files changed (1)
  1. app.py +31 -10
app.py CHANGED
@@ -109,9 +109,11 @@ def predict(
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-
-    torch.cuda.empty_cache()
+    try:
+        torch.cuda.manual_seed(seed)
+        torch.cuda.empty_cache()
+    except:
+        pass
 
     if isinstance(input_img, str):
         if input_img.startswith("http"):
@@ -129,7 +131,10 @@
         else:
             input_image = ImageOps.fit(input_image, (width, height), method=Image.LANCZOS)
         input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
-        input_image = rearrange(input_image, "h w c -> 1 c h w").cuda()
+        if torch.cuda.is_available():
+            input_image = rearrange(input_image, "h w c -> 1 c h w").cuda()
+        else:
+            input_image = rearrange(input_image, "h w c -> 1 c h w")
 
     # if PIL Image
     elif isinstance(input_img, Image.Image):
@@ -144,7 +149,10 @@
         else:
             input_image = ImageOps.fit(input_image, (width, height), method=Image.LANCZOS)
         input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
-        input_image = rearrange(input_image, "h w c -> 1 c h w").cuda()
+        if torch.cuda.is_available():
+            input_image = rearrange(input_image, "h w c -> 1 c h w").cuda()
+        else:
+            input_image = rearrange(input_image, "h w c -> 1 c h w")
     elif isinstance(input_img, dict):
         input_image = input_img["image"].convert("RGB")
         width, height = input_image.size
@@ -158,26 +166,36 @@
         else:
             input_image = ImageOps.fit(input_image, (width, height), method=Image.LANCZOS)
         input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
-        input_image = rearrange(input_image, "h w c -> 1 c h w").cuda()
+        if torch.cuda.is_available():
+            input_image = rearrange(input_image, "h w c -> 1 c h w").cuda()
+        else:
+            input_image = rearrange(input_image, "h w c -> 1 c h w")
 
     assert input_image is not None
     # print input image size
     print(input_image.shape, factor, width, height)
 
-    with torch.no_grad(), autocast("cuda"):
+    # with torch.no_grad(), autocast("cuda"):
+    with torch.no_grad():
         cond = {}
         cond["c_crossattn"] = [model.get_learned_conditioning([edit])]
         cond["c_concat"] = [model.encode_first_stage(input_image).mode()]
 
         uncond = {}
         if "txt_embed" in additional:
-            uncond["c_crossattn"] = [additional["txt_embed"].cuda().unsqueeze(0)]
+            if torch.cuda.is_available():
+                uncond["c_crossattn"] = [additional["txt_embed"].cuda().unsqueeze(0)]
+            else:
+                uncond["c_crossattn"] = [additional["txt_embed"].unsqueeze(0)]
         else:
             uncond["c_crossattn"] = [null_token]
         if "img_embed" in additional:
             # uncond["c_concat"] = [additional["img_embed"].cuda()]
             # resize to cond["c_concat"][0]
-            uncond["c_concat"] = [additional["img_embed"].cuda()]
+            if torch.cuda.is_available():
+                uncond["c_concat"] = [additional["img_embed"].cuda()]
+            else:
+                uncond["c_concat"] = [additional["img_embed"]]
             uncond["c_concat"][0] = F.interpolate(uncond["c_concat"][0], size=cond["c_concat"][0].shape[-2:], mode="bilinear", align_corners=False)
         else:
             uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]
@@ -269,7 +287,10 @@ def main(ckpt="checkpoints/v1-5-pruned-emaonly-adaption-task-humanalign.ckpt", a
 
     vae_ckpt = None
     model = load_model_from_config(config, ckpt, vae_ckpt)
-    model.eval().cuda()
+    if torch.cuda.is_available():
+        model.eval().cuda()
+    else:
+        model.eval()
 
     model_wrap = K.external.CompVisDenoiser(model)
     model_wrap_cfg = CFGDenoiser(model_wrap)
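
Note: this commit lets the Space fall back to CPU by guarding every hard-coded .cuda() call (and dropping the CUDA-only autocast) behind torch.cuda.is_available(). As a minimal sketch only, and not part of this commit, the same fallback is commonly written by selecting a device once and moving tensors and modules with .to(device); the names below (net, x) are purely illustrative:

import torch

# Pick the target device once; "cuda" only if a GPU is actually available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Illustrative module and input; .to(device) is effectively a no-op on CPU-only machines.
net = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).to(device).eval()
x = torch.randn(1, 3, 64, 64, device=device)

with torch.no_grad():
    y = net(x)  # runs on the GPU if present, otherwise on the CPU

print(y.shape, y.device)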