Ruining Li commited on
Commit
d04cea4
1 Parent(s): d9fca68

Adapt to HF ZeroGPU

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -98,6 +98,7 @@ def model_init():
98
  @spaces.GPU
99
  def sam_segment(predictor, input_image, drags, foreground_points=None):
100
  image = np.asarray(input_image)
 
101
  predictor.set_image(image)
102
 
103
  with torch.no_grad():
@@ -223,6 +224,7 @@ def single_image_sample(
223
  samples, _ = samples.chunk(2, dim=0)
224
  return samples
225
 
 
226
  def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion, img_cond, seed, cfg_scale, drags_list):
227
  if img_cond is None:
228
  gr.Warning("Please preprocess the image first.")
@@ -266,7 +268,7 @@ def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion,
266
  break
267
 
268
  samples = single_image_sample(
269
- model,
270
  diffusion,
271
  x_cond,
272
  cond_clip_features,
 
98
  @spaces.GPU
99
  def sam_segment(predictor, input_image, drags, foreground_points=None):
100
  image = np.asarray(input_image)
101
+ predictor = predictor.to("cuda")
102
  predictor.set_image(image)
103
 
104
  with torch.no_grad():
 
224
  samples, _ = samples.chunk(2, dim=0)
225
  return samples
226
 
227
+ @spaces.GPU
228
  def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion, img_cond, seed, cfg_scale, drags_list):
229
  if img_cond is None:
230
  gr.Warning("Please preprocess the image first.")
 
268
  break
269
 
270
  samples = single_image_sample(
271
+ model.to("cuda"),
272
  diffusion,
273
  x_cond,
274
  cond_clip_features,