merve HF staff commited on
Commit
be35f94
1 Parent(s): ffc0c3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -3,11 +3,13 @@ import torch
3
  import numpy as np
4
  from PIL import Image
5
  import matplotlib.pyplot as plt
 
 
6
 
7
- depth_anything = pipeline(task = "depth-estimation", model="nielsr/depth-anything-small", device=0)
8
  checkpoint = "BAAI/seggpt-vit-large"
9
  image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
10
- model = SegGptForImageSegmentation.from_pretrained(checkpoint)
11
 
12
  def infer_seggpt(image_input, image_prompt, mask_prompt):
13
  num_labels = 100
@@ -17,7 +19,7 @@ def infer_seggpt(image_input, image_prompt, mask_prompt):
17
  prompt_masks=mask_prompt,
18
  return_tensors="pt",
19
  num_labels=num_labels
20
- )
21
  with torch.no_grad():
22
  outputs = model(**inputs)
23
 
@@ -38,6 +40,7 @@ def infer_seggpt(image_input, image_prompt, mask_prompt):
38
  plt.savefig("masks.png", bbox_inches='tight', pad_inches=0)
39
  return "masks.png"
40
 
 
41
  def infer(image_input, image_prompt, mask_prompt):
42
  sg_masks = []
43
  mask_prompt = depth_anything(image_prompt)["depth"].convert("RGB")
 
3
  import numpy as np
4
  from PIL import Image
5
  import matplotlib.pyplot as plt
6
+ import spaces
7
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8
 
9
+ depth_anything = pipeline(task = "depth-estimation", model="nielsr/depth-anything-small", device=device)
10
  checkpoint = "BAAI/seggpt-vit-large"
11
  image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
12
+ model = SegGptForImageSegmentation.from_pretrained(checkpoint).to(device)
13
 
14
  def infer_seggpt(image_input, image_prompt, mask_prompt):
15
  num_labels = 100
 
19
  prompt_masks=mask_prompt,
20
  return_tensors="pt",
21
  num_labels=num_labels
22
+ ).to(device)
23
  with torch.no_grad():
24
  outputs = model(**inputs)
25
 
 
40
  plt.savefig("masks.png", bbox_inches='tight', pad_inches=0)
41
  return "masks.png"
42
 
43
+ @spaces.GPU
44
  def infer(image_input, image_prompt, mask_prompt):
45
  sg_masks = []
46
  mask_prompt = depth_anything(image_prompt)["depth"].convert("RGB")