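# Provenance (scraped from the Hugging Face Spaces file viewer): author fffiloni,
# commit d989292 "Update app.py", file size 1.95 kB, virus scan: "No virus".
# (Viewer chrome at capture time: avatar "fffiloni's picture"; links raw · history · blame.)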
from diffusers import StableDiffusionInpaintPipeline
import gradio as gr
import numpy as np
import imageio
from PIL import Image
from io import BytesIO
import os
# Hugging Face access token read from the HF_TOKEN_SD environment variable
# (None when the variable is unset). It is forwarded to from_pretrained below;
# presumably required because the checkpoint was gated — confirm.
MY_SECRET_TOKEN=os.environ.get('HF_TOKEN_SD')
# NOTE(review): looks like a leftover debug print; it runs on every import.
print("hello sylvain")
# Plain alias of MY_SECRET_TOKEN; this is the name passed to the pipeline.
YOUR_TOKEN=MY_SECRET_TOKEN
# All inference runs on the CPU.
device="cpu"
# Build the Stable Diffusion v1-4 inpainting pipeline once, at import time, so
# every request reuses it. The weights are presumably fetched from the Hugging
# Face Hub on first run (cached afterwards) — this can take a while.
pipe = StableDiffusionInpaintPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=YOUR_TOKEN)
pipe.to(device)
# Input component: an uploaded image that the user paints a mask onto with the
# sketch tool. The request handler indexes its value with "image" and "mask".
source_img = gr.Image(source="upload", type="numpy", tool="sketch", elem_id="source_container");
# Output component: gallery for the generated images (grid=[2] presumably
# means two images per row — confirm against the Gradio version in use).
gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery").style(grid=[2], height="auto")
def resize(height, img):
    """Open an image and scale it to a fixed height, keeping its aspect ratio.

    Args:
        height: Target height in pixels (an int, e.g. 512).
        img: Anything ``PIL.Image.open`` accepts, such as a file path or a
            binary file-like object.

    Returns:
        A new ``PIL.Image.Image`` of size ``(new_width, height)``, where
        ``new_width`` is the source width scaled by ``height / source_height``
        and rounded down.
    """
    with Image.open(img) as source:
        src_width, src_height = source.size
        # Exact integer arithmetic. The previous float form,
        # int(src_width * (height / src_height)), can land just below a whole
        # number and truncate one pixel short. For example, a 49x49 or 784x784
        # image came out 511x512 instead of 512x512, which then fails the
        # diffusion pipeline's size constraints.
        new_width = src_width * height // src_height
        # resize() loads the pixel data and returns an independent image, so the
        # result stays valid after the context manager closes the source file.
        return source.resize((new_width, height), Image.Resampling.LANCZOS)
def predict(prompt, source_img):
    """Inpaint the user-painted region of an uploaded image with Stable Diffusion.

    Args:
        prompt: Text prompt describing what to paint into the masked area.
        source_img: Value of the sketch-enabled ``gr.Image`` input, a mapping
            with an ``"image"`` array and a ``"mask"`` array.

    Returns:
        A list with one image per generated sample (two per request). Any
        sample the pipeline flags as NSFW is replaced by the placeholder image
        loaded from ``unsafe.png``.

    Raises:
        ValueError: If no image was uploaded (``source_img`` is None).
    """
    if source_img is None:
        raise ValueError("Please upload an image and paint a mask over the area to inpaint.")

    # Round-trip both arrays through PNG files so resize() can reopen them and
    # scale them to a 512-pixel height; the resized copies are saved as well.
    # NOTE(review): these fixed filenames in the working directory are shared by
    # all requests, so this is only safe while requests are processed one at a
    # time — confirm the queue's concurrency settings.
    imageio.imwrite("data.png", source_img["image"])
    imageio.imwrite("data_mask.png", source_img["mask"])
    src = resize(512, "data.png")
    src.save("src.png")
    mask = resize(512, "data_mask.png")
    mask.save("mask.png")

    # Two samples of the same prompt per request.
    output = pipe([prompt] * 2, init_image=src, mask_image=mask, strength=0.75)

    placeholder = None
    images = []
    for image, flagged in zip(output["sample"], output["nsfw_content_detected"]):
        if not flagged:
            images.append(image)
            continue
        # Load the placeholder only when a sample is actually flagged. copy()
        # reads it fully into memory, so the file can be closed immediately
        # instead of leaking one open handle per request.
        if placeholder is None:
            with Image.open("unsafe.png") as placeholder_file:
                placeholder = placeholder_file.copy()
        images.append(placeholder)
    return images
# Page-level configuration for the Gradio interface: stylesheet file, page
# title, and the HTML description shown above the inputs.
custom_css = "style.css"
title = "InPainting Stable Diffusion CPU"
description = (
    "Inpainting Stable Diffusion example using CPU and HF token. <br />"
    "Warning: Slow process... ~5/10 min inference time. "
    "<b>NSFW filter enabled.</b><br />"
    "Please use 512*512 square image as input to avoid memory error !"
)

# Wire predict() to a free-text prompt plus the sketchable source image,
# render its results in the gallery, then start the server with queueing
# enabled (inference takes minutes per request).
demo = gr.Interface(
    fn=predict,
    inputs=["text", source_img],
    outputs=gallery,
    css=custom_css,
    title=title,
    description=description,
)
demo.launch(enable_queue=True)