nightfury commited on
Commit
332b5a0
1 Parent(s): 13039d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -0
app.py CHANGED
@@ -103,6 +103,17 @@ def predict(radio, dict, word_mask, prompt=""):
103
  mask = dict["mask"].convert("RGB").resize((imgRes, imgRes))
104
  elif(radio == "type what to keep"):
105
  img = transform(dict["image"]).squeeze(0)
 
 
 
 
 
 
 
 
 
 
 
106
  word_masks = [word_mask]
107
  with torch.no_grad():
108
  #torch.cuda.amp.autocast(): #
@@ -118,6 +129,17 @@ def predict(radio, dict, word_mask, prompt=""):
118
  os.remove(filename)
119
  else:
120
  img = transform(dict["image"]).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
 
121
  word_masks = [word_mask]
122
  #with torch.cuda.amp.autocast(): #
123
  with torch.no_grad():
 
103
  mask = dict["mask"].convert("RGB").resize((imgRes, imgRes))
104
  elif(radio == "type what to keep"):
105
  img = transform(dict["image"]).squeeze(0)
106
+
107
+ #-----New Lines-----
108
+ if torch.cuda.is_available():
109
+ img.cuda()
110
+ print ("yes, CUDA is available here !! ")
111
+
112
+ model = model.to(torch.device(device))
113
+ img = img.to(torch.device(device))
114
+ prompt = labels.to(torch.device(device))
115
+ #------------------
116
+
117
  word_masks = [word_mask]
118
  with torch.no_grad():
119
  #torch.cuda.amp.autocast(): #
 
129
  os.remove(filename)
130
  else:
131
  img = transform(dict["image"]).unsqueeze(0)
132
+
133
+ #-----New Lines-----
134
+ if torch.cuda.is_available():
135
+ img.cuda()
136
+ print ("yes, CUDA is available here !! ")
137
+
138
+ model = model.to(torch.device(device))
139
+ img = img.to(torch.device(device))
140
+ prompt = labels.to(torch.device(device))
141
+ #------------------
142
+
143
  word_masks = [word_mask]
144
  #with torch.cuda.amp.autocast(): #
145
  with torch.no_grad():