nightfury commited on
Commit
b61fad9
1 Parent(s): 95bf211

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -4
app.py CHANGED
@@ -73,9 +73,23 @@ transform = transforms.Compose([
73
 
74
  def predict(radio, dict, word_mask, prompt=""):
75
  if(radio == "draw a mask above"):
76
- with autocast("cuda"):
77
  init_image = dict["image"].convert("RGB").resize((512, 512))
78
  mask = dict["mask"].convert("RGB").resize((512, 512))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  else:
80
  img = transform(dict["image"]).unsqueeze(0)
81
  word_masks = [word_mask]
@@ -90,7 +104,7 @@ def predict(radio, dict, word_mask, prompt=""):
90
  cv2.cvtColor(bw_image, cv2.COLOR_BGR2RGB)
91
  mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')
92
  os.remove(filename)
93
- with autocast("cuda"):
94
  images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
95
  return images[0]
96
 
@@ -175,7 +189,7 @@ with image_blocks as demo:
175
  with gr.Column():
176
  image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Upload").style(height=400)
177
  with gr.Box(elem_id="mask_radio").style(border=False):
178
- radio = gr.Radio(["draw a mask above", "type what to mask below"], value="draw a mask above", show_label=False, interactive=True).style(container=False)
179
  word_mask = gr.Textbox(label = "What to find in your image", interactive=False, elem_id="word_mask", placeholder="Disabled").style(container=False)
180
  prompt = gr.Textbox(label = 'Your prompt (what you want to add in place of what you are removing)')
181
  radio.change(fn=swap_word_mask, inputs=radio, outputs=word_mask,show_progress=False)
@@ -192,7 +206,7 @@ with image_blocks as demo:
192
  gr.HTML(
193
  """
194
  <div class="footer">
195
- <p>Model by <a href="https://huggingface.co/CompVis" style="text-decoration: underline;" target="_blank">CompVis</a> and <a href="https://huggingface.co/stabilityai" style="text-decoration: underline;" target="_blank">Stability AI</a> - Inpainting by <a href="https://github.com/nagolinc" style="text-decoration: underline;" target="_blank">nagolinc</a> and <a href="https://github.com/patil-suraj" style="text-decoration: underline;">patil-suraj</a>, inpainting with words by <a href="https://twitter.com/yvrjsharma/" style="text-decoration: underline;" target="_blank">@yvrjsharma</a> and <a href="https://twitter.com/1littlecoder" style="text-decoration: underline;">@1littlecoder</a> - Gradio Demo by 🤗 Hugging Face
196
  </p>
197
  </div>
198
  <div class="acknowledgments">
 
73
 
74
  def predict(radio, dict, word_mask, prompt=""):
75
  if(radio == "draw a mask above"):
76
+ with autocast(device): #"cuda"
77
  init_image = dict["image"].convert("RGB").resize((512, 512))
78
  mask = dict["mask"].convert("RGB").resize((512, 512))
79
+ else if(radio == "type what to keep"):
80
+ img = transform(dict["image"]).squeeze(0)
81
+ word_masks = [word_mask]
82
+ with torch.no_grad():
83
+ preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
84
+ init_image = dict['image'].convert('RGB').resize((512, 512))
85
+ filename = f"{uuid.uuid4()}.png"
86
+ plt.imsave(filename,torch.sigmoid(preds[0][0]))
87
+ img2 = cv2.imread(filename)
88
+ gray_image = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
89
+ (thresh, bw_image) = cv2.threshold(gray_image, 100, 255, cv2.THRESH_BINARY)
90
+ cv2.cvtColor(bw_image, cv2.COLOR_BGR2RGB)
91
+ mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')
92
+ os.remove(filename)
93
  else:
94
  img = transform(dict["image"]).unsqueeze(0)
95
  word_masks = [word_mask]
 
104
  cv2.cvtColor(bw_image, cv2.COLOR_BGR2RGB)
105
  mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')
106
  os.remove(filename)
107
+ with autocast(device): #"cuda"
108
  images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
109
  return images[0]
110
 
 
189
  with gr.Column():
190
  image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Upload").style(height=400)
191
  with gr.Box(elem_id="mask_radio").style(border=False):
192
+ radio = gr.Radio(["draw a mask above", "type what to mask below", "type what to keep"], value="draw a mask above", show_label=False, interactive=True).style(container=False)
193
  word_mask = gr.Textbox(label = "What to find in your image", interactive=False, elem_id="word_mask", placeholder="Disabled").style(container=False)
194
  prompt = gr.Textbox(label = 'Your prompt (what you want to add in place of what you are removing)')
195
  radio.change(fn=swap_word_mask, inputs=radio, outputs=word_mask,show_progress=False)
 
206
  gr.HTML(
207
  """
208
  <div class="footer">
209
+ <p>Model by <a href="https://huggingface.co/CompVis" style="text-decoration: underline;" target="_blank">CompVis</a> and <a href="https://huggingface.co/stabilityai" style="text-decoration: underline;" target="_blank">Stability AI</a> - Inpainting by <a href="https://github.com/" style="text-decoration: underline;" target="_blank">NightFury</a> using clipseg[model] with bit modification - Gradio Demo on 🤗 Hugging Face
210
  </p>
211
  </div>
212
  <div class="acknowledgments">