nightfury commited on
Commit
f41534f
1 Parent(s): 0bf9897

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -68,9 +68,8 @@ print ("torch.backends.mps.is_available: ", torch.backends.mps.is_available())
68
  pipe = pipe.to(device)
69
 
70
  model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
71
- model.load_state_dict(torch.load('./clipseg/weights/rd64-uni.pth', map_location=torch.device(device)), strict=False) #False
72
-
73
  model.eval().half()
 
74
 
75
  imgRes = 256
76
 
@@ -88,7 +87,8 @@ def predict(radio, dict, word_mask, prompt=""):
88
  elif(radio == "type what to keep"):
89
  img = transform(dict["image"]).squeeze(0)
90
  word_masks = [word_mask]
91
- with torch.cuda.amp.autocast(): #with torch.no_grad():
 
92
  preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
93
  init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
94
  filename = f"{uuid.uuid4()}.png"
@@ -102,7 +102,8 @@ def predict(radio, dict, word_mask, prompt=""):
102
  else:
103
  img = transform(dict["image"]).unsqueeze(0)
104
  word_masks = [word_mask]
105
- with torch.cuda.amp.autocast(): #with torch.no_grad():
 
106
  preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
107
  init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
108
  filename = f"{uuid.uuid4()}.png"
68
  pipe = pipe.to(device)
69
 
70
  model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
 
 
71
  model.eval().half()
72
+ model.load_state_dict(torch.load('./clipseg/weights/rd64-uni.pth', map_location=torch.device(device)), strict=False) #False
73
 
74
  imgRes = 256
75
 
87
  elif(radio == "type what to keep"):
88
  img = transform(dict["image"]).squeeze(0)
89
  word_masks = [word_mask]
90
+ with torch.no_grad():
91
+ #torch.cuda.amp.autocast(): #
92
  preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
93
  init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
94
  filename = f"{uuid.uuid4()}.png"
102
  else:
103
  img = transform(dict["image"]).unsqueeze(0)
104
  word_masks = [word_mask]
105
+ #with torch.cuda.amp.autocast(): #
106
+ with torch.no_grad():
107
  preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
108
  init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
109
  filename = f"{uuid.uuid4()}.png"