# afrofusion / app.py
# netrosec's picture
# Create app.py
# d00dc43 verified
import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# script can still start on machines without CUDA instead of failing at import.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Segment Anything checkpoint on local disk and the registry key of the
# architecture it was trained with.
sam_checkpoint = "/home/jupyter/diffusers/examples/sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Look up the model constructor in the registry, build the model from the
# checkpoint, and move it to the selected device.
model_fn = sam_model_registry[model_type]
model = model_fn(checkpoint=sam_checkpoint)
model.to(device)
predictor = SamPredictor(model)

# Half precision is kept for CUDA; on the CPU fall back to float32, since
# float16 inference is generally not usable there.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
pipe = pipe.to(device)

# Click coordinates collected so far. This is module-level state that the
# image click handler appends to and reads on every click.
selected_pixels = []
with gr.Blocks() as demo:
    # Layout: input image, predicted mask and inpainted result in one row,
    # followed by the prompt box and the submit button.
    with gr.Row():
        input_img = gr.Image(label="Input")
        mask_img = gr.Image(label="Mask")
        output_img = gr.Image(label="Output")

    with gr.Row():
        prompt_text = gr.Textbox(lines=1, label="Prompt")

    with gr.Row():
        submit = gr.Button("Submit")

    def reset_selection():
        """Forget every recorded click point and blank the mask preview.

        Without this, points clicked on a previous image would still be fed
        to the predictor as prompts for a newly uploaded image.

        Returns:
            None, which clears the component this handler is wired to.
        """
        selected_pixels.clear()
        return None

    def generate_mask(image, evt: gr.SelectData):
        """Record the clicked pixel and predict a mask from all clicks so far.

        Args:
            image: Current content of the input image component.
            evt: Gradio select event; ``evt.index`` holds the clicked
                coordinates.

        Returns:
            A PIL image built from the first mask returned by the predictor.
        """
        selected_pixels.append(evt.index)

        # NOTE: the image is handed to the predictor again on every click.
        predictor.set_image(image)

        input_points = np.array(selected_pixels)
        # One label per point, all set to 1 (SAM's "foreground" label).
        input_labels = np.ones(input_points.shape[0])

        mask, _, _ = predictor.predict(
            point_coords=input_points,
            point_labels=input_labels,
            multimask_output=False,
        )
        # Masks come back stacked as (n, sz, sz); keep only the first one.
        return Image.fromarray(mask[0, :, :])

    def inpaint(image, mask, prompt):
        """Inpaint the masked region of ``image`` according to ``prompt``.

        Args:
            image: Input image as an array.
            mask: Mask image as an array.
            prompt: Text prompt passed to the inpainting pipeline.

        Returns:
            The first image produced by the pipeline.

        Raises:
            gr.Error: If the input image or the mask is missing.
        """
        if image is None or mask is None:
            # Report a readable message in the UI instead of an opaque
            # failure inside Image.fromarray(None).
            raise gr.Error(
                "Upload an image and click on it to create a mask before submitting."
            )

        # Both inputs are resized to the same fixed 512x512 resolution
        # before being passed to the pipeline.
        image = Image.fromarray(image).resize((512, 512))
        mask = Image.fromarray(mask).resize((512, 512))

        output = pipe(
            prompt=prompt,
            image=image,
            mask_image=mask,
        ).images[0]
        return output

    # Clicking on the input image adds a point and refreshes the mask.
    input_img.select(generate_mask, [input_img], [mask_img])
    # A new or removed input image invalidates previously clicked points.
    input_img.upload(reset_selection, inputs=None, outputs=[mask_img])
    input_img.clear(reset_selection, inputs=None, outputs=[mask_img])
    submit.click(inpaint, inputs=[input_img, mask_img, prompt_text], outputs=[output_img])
# Start the Gradio app only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()