juxuan27 commited on
Commit
efdf9b6
1 Parent(s): 54eb57c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -4
app.py CHANGED
@@ -15,9 +15,8 @@ import torch
15
  from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
16
  import random
17
  import gradio as gr
18
- import spaces
19
 
20
- mobile_sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth')
21
  mobile_sam.eval()
22
  mobile_predictor = SamPredictor(mobile_sam)
23
  colors = [(255, 0, 0), (0, 255, 0)]
@@ -74,7 +73,6 @@ def resize_image(input_image, resolution):
74
  img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
75
  return img
76
 
77
- @spaces.GPU
78
  def process(input_image,
79
  original_image,
80
  original_mask,
@@ -275,7 +273,7 @@ with block:
275
  for p, l in sel_pix:
276
  points.append(p)
277
  labels.append(l)
278
- mobile_predictor=mobile_predictor.to("cuda")
279
  mobile_predictor.set_image(img if isinstance(img, np.ndarray) else np.array(img))
280
  with torch.no_grad():
281
  masks, _, _ = mobile_predictor.predict(point_coords=np.array(points), point_labels=np.array(labels), multimask_output=False)
 
15
  from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
16
  import random
17
  import gradio as gr
 
18
 
19
+ mobile_sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth').to("cuda")
20
  mobile_sam.eval()
21
  mobile_predictor = SamPredictor(mobile_sam)
22
  colors = [(255, 0, 0), (0, 255, 0)]
 
73
  img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
74
  return img
75
 
 
76
  def process(input_image,
77
  original_image,
78
  original_mask,
 
273
  for p, l in sel_pix:
274
  points.append(p)
275
  labels.append(l)
276
+ mobile_predictor=mobile_predictor
277
  mobile_predictor.set_image(img if isinstance(img, np.ndarray) else np.array(img))
278
  with torch.no_grad():
279
  masks, _, _ = mobile_predictor.predict(point_coords=np.array(points), point_labels=np.array(labels), multimask_output=False)