drscotthawley commited on
Commit
3f93e88
1 Parent(s): b98fe4a

mods to get ZeroGPU working

Browse files
Files changed (2) hide show
  1. app.py +3 -2
  2. sample.py +1 -1
app.py CHANGED
@@ -42,15 +42,16 @@ def infer_mask_from_init_img(img, mask_with='grey'):
42
  "note, this works whether image is normalized on 0..1 or -1..1, but not 0..255"
43
  assert mask_with in ['blue','white','grey']
44
  "given an image with mask areas marked, extract the mask itself"
 
45
  if not torch.is_tensor(img):
46
  img = ToTensor()(img)
47
- print("img.shape: ", img.shape)
48
  # shape of mask should be img shape without the channel dimension
49
  if len(img.shape) == 3:
50
  mask = torch.zeros(img.shape[-2:])
51
  elif len(img.shape) == 2:
52
  mask = torch.zeros(img.shape)
53
- print("mask.shape: ", mask.shape)
54
  if mask_with == 'white':
55
  mask[ (img[0,:,:]==1) & (img[1,:,:]==1) & (img[2,:,:]==1)] = 1
56
  elif mask_with == 'blue':
 
42
  "note, this works whether image is normalized on 0..1 or -1..1, but not 0..255"
43
  assert mask_with in ['blue','white','grey']
44
  "given an image with mask areas marked, extract the mask itself"
45
+ print("\n in infer_mask_from_init_img: ")
46
  if not torch.is_tensor(img):
47
  img = ToTensor()(img)
48
+ print(" img.shape: ", img.shape)
49
  # shape of mask should be img shape without the channel dimension
50
  if len(img.shape) == 3:
51
  mask = torch.zeros(img.shape[-2:])
52
  elif len(img.shape) == 2:
53
  mask = torch.zeros(img.shape)
54
+ print(" mask.shape: ", mask.shape)
55
  if mask_with == 'white':
56
  mask[ (img[0,:,:]==1) & (img[1,:,:]==1) & (img[2,:,:]==1)] = 1
57
  elif mask_with == 'blue':
sample.py CHANGED
@@ -522,7 +522,7 @@ def get_init_image_and_mask(args, device):
522
  init_mask = init_mask.unsqueeze(0).unsqueeze(1).repeat(args.batch_size,3,1,1).float()
523
  return init_image.to(device), init_mask.to(device)
524
 
525
- @spaces.GPU
526
  def main():
527
  global init_image, init_mask
528
  p = argparse.ArgumentParser(description=__doc__,
 
522
  init_mask = init_mask.unsqueeze(0).unsqueeze(1).repeat(args.batch_size,3,1,1).float()
523
  return init_image.to(device), init_mask.to(device)
524
 
525
+ #@spaces.GPU # generates an error
526
  def main():
527
  global init_image, init_mask
528
  p = argparse.ArgumentParser(description=__doc__,