Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -73,9 +73,23 @@ transform = transforms.Compose([
|
|
73 |
|
74 |
def predict(radio, dict, word_mask, prompt=""):
|
75 |
if(radio == "draw a mask above"):
|
76 |
-
with autocast("cuda"
|
77 |
init_image = dict["image"].convert("RGB").resize((512, 512))
|
78 |
mask = dict["mask"].convert("RGB").resize((512, 512))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
else:
|
80 |
img = transform(dict["image"]).unsqueeze(0)
|
81 |
word_masks = [word_mask]
|
@@ -90,7 +104,7 @@ def predict(radio, dict, word_mask, prompt=""):
|
|
90 |
cv2.cvtColor(bw_image, cv2.COLOR_BGR2RGB)
|
91 |
mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')
|
92 |
os.remove(filename)
|
93 |
-
with autocast("cuda"
|
94 |
images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
|
95 |
return images[0]
|
96 |
|
@@ -175,7 +189,7 @@ with image_blocks as demo:
|
|
175 |
with gr.Column():
|
176 |
image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Upload").style(height=400)
|
177 |
with gr.Box(elem_id="mask_radio").style(border=False):
|
178 |
-
radio = gr.Radio(["draw a mask above", "type what to mask below"], value="draw a mask above", show_label=False, interactive=True).style(container=False)
|
179 |
word_mask = gr.Textbox(label = "What to find in your image", interactive=False, elem_id="word_mask", placeholder="Disabled").style(container=False)
|
180 |
prompt = gr.Textbox(label = 'Your prompt (what you want to add in place of what you are removing)')
|
181 |
radio.change(fn=swap_word_mask, inputs=radio, outputs=word_mask,show_progress=False)
|
@@ -192,7 +206,7 @@ with image_blocks as demo:
|
|
192 |
gr.HTML(
|
193 |
"""
|
194 |
<div class="footer">
|
195 |
-
<p>Model by <a href="https://huggingface.co/CompVis" style="text-decoration: underline;" target="_blank">CompVis</a> and <a href="https://huggingface.co/stabilityai" style="text-decoration: underline;" target="_blank">Stability AI</a> - Inpainting by <a href="https://github.com/
|
196 |
</p>
|
197 |
</div>
|
198 |
<div class="acknowledgments">
|
|
|
73 |
|
74 |
def predict(radio, dict, word_mask, prompt=""):
|
75 |
if(radio == "draw a mask above"):
|
76 |
+
with autocast(device): #"cuda"
|
77 |
init_image = dict["image"].convert("RGB").resize((512, 512))
|
78 |
mask = dict["mask"].convert("RGB").resize((512, 512))
|
79 |
+
else if(radio == "type what to keep"):
|
80 |
+
img = transform(dict["image"]).squeeze(0)
|
81 |
+
word_masks = [word_mask]
|
82 |
+
with torch.no_grad():
|
83 |
+
preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
|
84 |
+
init_image = dict['image'].convert('RGB').resize((512, 512))
|
85 |
+
filename = f"{uuid.uuid4()}.png"
|
86 |
+
plt.imsave(filename,torch.sigmoid(preds[0][0]))
|
87 |
+
img2 = cv2.imread(filename)
|
88 |
+
gray_image = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
|
89 |
+
(thresh, bw_image) = cv2.threshold(gray_image, 100, 255, cv2.THRESH_BINARY)
|
90 |
+
cv2.cvtColor(bw_image, cv2.COLOR_BGR2RGB)
|
91 |
+
mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')
|
92 |
+
os.remove(filename)
|
93 |
else:
|
94 |
img = transform(dict["image"]).unsqueeze(0)
|
95 |
word_masks = [word_mask]
|
|
|
104 |
cv2.cvtColor(bw_image, cv2.COLOR_BGR2RGB)
|
105 |
mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')
|
106 |
os.remove(filename)
|
107 |
+
with autocast(device): #"cuda"
|
108 |
images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
|
109 |
return images[0]
|
110 |
|
|
|
189 |
with gr.Column():
|
190 |
image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Upload").style(height=400)
|
191 |
with gr.Box(elem_id="mask_radio").style(border=False):
|
192 |
+
radio = gr.Radio(["draw a mask above", "type what to mask below", "type what to keep"], value="draw a mask above", show_label=False, interactive=True).style(container=False)
|
193 |
word_mask = gr.Textbox(label = "What to find in your image", interactive=False, elem_id="word_mask", placeholder="Disabled").style(container=False)
|
194 |
prompt = gr.Textbox(label = 'Your prompt (what you want to add in place of what you are removing)')
|
195 |
radio.change(fn=swap_word_mask, inputs=radio, outputs=word_mask,show_progress=False)
|
|
|
206 |
gr.HTML(
|
207 |
"""
|
208 |
<div class="footer">
|
209 |
+
<p>Model by <a href="https://huggingface.co/CompVis" style="text-decoration: underline;" target="_blank">CompVis</a> and <a href="https://huggingface.co/stabilityai" style="text-decoration: underline;" target="_blank">Stability AI</a> - Inpainting by <a href="https://github.com/" style="text-decoration: underline;" target="_blank">NightFury</a> using clipseg[model] with bit modification - Gradio Demo on 🤗 Hugging Face
|
210 |
</p>
|
211 |
</div>
|
212 |
<div class="acknowledgments">
|