# Hugging Face Spaces page status captured with this source: "Runtime error".
import os

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry
# Use the GPU when one is available; otherwise fall back to CPU so the app can
# still start on CPU-only hardware (previously "cuda" was hard-coded).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Path to the SAM ViT-H checkpoint. Override it with the SAM_CHECKPOINT
# environment variable; the default is the original machine-specific location.
sam_checkpoint = os.environ.get(
    "SAM_CHECKPOINT", "/home/jupyter/diffusers/examples/sam_vit_h_4b8939.pth"
)
if not os.path.isfile(sam_checkpoint):
    # Fail early with an actionable message instead of an opaque error from
    # deep inside the model builder.
    raise FileNotFoundError(
        f"SAM checkpoint not found at {sam_checkpoint!r}; "
        "set the SAM_CHECKPOINT environment variable to its path."
    )

# Registry key selecting the SAM architecture; must match the checkpoint
# (the default file name above is a ViT-H checkpoint).
model_type = "vit_h"

# Build the SAM model from the registry, loading weights from the checkpoint,
# then move it to the selected device.
model_fn = sam_model_registry[model_type]
model = model_fn(checkpoint=sam_checkpoint)
model.to(device)

# Wrapper used by the UI callbacks: set_image() once per image, then predict()
# from point prompts.
predictor = SamPredictor(model)
# Stable Diffusion 2 inpainting pipeline. fp16 halves memory and speeds up
# inference on CUDA, but half precision is generally unsupported for CPU
# inference, so load fp32 weights when not running on a GPU.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float16 if str(device).startswith("cuda") else torch.float32,
)
pipe = pipe.to(device)
def generate_mask(image, points, evt: gr.SelectData):
    """Segment the region under the clicked points with SAM.

    Args:
        image: The input image delivered by ``gr.Image`` (a numpy array with
            gradio's default ``type="numpy"``).
        points: This session's previously clicked points (from ``gr.State``).
        evt: The click event; ``evt.index`` is the clicked pixel, passed to SAM
            as a point prompt (assumed ``[x, y]`` order, as SAM expects).

    Returns:
        A tuple ``(mask, points)``. ``mask`` is a single-channel PIL image where
        255 marks the segmented region and 0 the background (or ``None`` if
        there is no image). ``points`` is the updated point list, stored back
        into the session state.
    """
    if image is None:
        return None, []
    # Build a new list rather than mutating the state value in place.
    points = points + [list(evt.index)]

    # NOTE: `predictor` is shared by all sessions and keeps the image embedding
    # between set_image() and predict(); concurrent callbacks would race here.
    predictor.set_image(image)
    coords = np.array(points)
    labels = np.ones(coords.shape[0])  # 1 = foreground point for every click
    masks, _, _ = predictor.predict(
        point_coords=coords,
        point_labels=labels,
        multimask_output=False,
    )
    # masks has shape (n_masks, H, W); with multimask_output=False we use the
    # first (only) one. Store it as 8-bit 0/255 so white marks the region to
    # repaint, as the inpainting pipeline expects.
    mask = Image.fromarray(masks[0].astype(np.uint8) * 255)
    return mask, points


def inpaint(image, mask, prompt):
    """Repaint the masked region of ``image`` according to ``prompt``.

    Args:
        image: Input image as a numpy array (from ``gr.Image``).
        mask: Mask image as a numpy array; white pixels are repainted.
        prompt: Text prompt guiding the inpainting.

    Returns:
        The inpainted PIL image, resized back to the input image's size.

    Raises:
        gr.Error: If the input image or the mask is missing.
    """
    if image is None:
        raise gr.Error("Please upload an input image first.")
    if mask is None:
        raise gr.Error("Click on the input image to create a mask first.")

    source = Image.fromarray(image)
    original_size = source.size
    # The pipeline is run at 512x512; both inputs must share that size.
    source = source.resize((512, 512))
    mask_image = Image.fromarray(mask).resize((512, 512))
    result = pipe(
        prompt=prompt,
        image=source,
        mask_image=mask_image,
    ).images[0]
    # Undo the square resize so the output matches the input's dimensions.
    return result.resize(original_size)


def _reset_selection():
    """Forget clicked points and clear the mask when the input image changes."""
    return [], None


with gr.Blocks() as demo:
    # Per-session list of clicked points. gr.State gives each browser session
    # its own copy, unlike a module-level list shared by every user.
    selected_points = gr.State([])

    with gr.Row():
        input_img = gr.Image(label="Input")
        mask_img = gr.Image(label="Mask")
        output_img = gr.Image(label="Output")
    with gr.Row():
        prompt_text = gr.Textbox(lines=1, label="Prompt")
    with gr.Row():
        submit = gr.Button("Submit")

    # Clicking the input image adds a point and recomputes the mask.
    input_img.select(
        generate_mask,
        [input_img, selected_points],
        [mask_img, selected_points],
    )
    # Points clicked on a previous image must not apply to a new (or cleared) one.
    input_img.change(_reset_selection, None, [selected_points, mask_img])
    submit.click(inpaint, inputs=[input_img, mask_img, prompt_text], outputs=[output_img])
if __name__ == "__main__":
    # Route requests through Gradio's queue: SAM + Stable Diffusion inference is
    # long-running, and the shared SAM predictor keeps per-image state between
    # set_image() and predict(), so callbacks should not run unqueued in parallel.
    demo.queue().launch()