File size: 2,452 Bytes
d6e5ca6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82

from diffusers import StableDiffusionInpaintPipeline
import torch

import os

model_id = 'stabilityai/stable-diffusion-2-inpainting'

# Run on the GPU when one is present; otherwise fall back to the CPU.
# Half precision is only used on CUDA because many CPU kernels do not
# support float16.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
sd_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
)
sd_pipeline = sd_pipeline.to(DEVICE)

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# MODEL_TYPE and CHECKPOINT_PATH were referenced without ever being defined,
# which made the script fail with NameError at import time.  They can be
# overridden through the environment; the defaults name the official ViT-H
# Segment Anything checkpoint, which must be present on disk.
MODEL_TYPE = os.environ.get("SAM_MODEL_TYPE", "vit_h")
CHECKPOINT_PATH = os.environ.get("SAM_CHECKPOINT_PATH", "sam_vit_h_4b8939.pth")

sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(DEVICE)

# Predictor used by the click-to-segment callback.
predictor = SamPredictor(sam)

import gradio as gr
import numpy as np
from PIL import Image

# Module-level state shared by the Gradio callbacks in this file:
#   selected_pixels - every click position recorded on the input image
#   isInvert        - falsy/truthy flag; when truthy the predicted mask is
#                     inverted before it is shown
selected_pixels, isInvert = [], 0

with gr.Blocks() as genaieg:
  # Layout: input image and generated mask side by side, result underneath.
  with gr.Row():
    input_img = gr.Image(label='Input')
    mask_img = gr.Image(label="Mask")
  with gr.Row():
    output_img = gr.Image(label="Output")

  def invertmask():
    """Toggle the module-level flag that inverts masks from later clicks."""
    global isInvert
    isInvert = not isInvert

  def reset_points():
    """Drop clicks recorded for the previous image.

    Without this, points selected on an earlier upload would still be fed
    to the predictor for the new image.
    """
    selected_pixels.clear()

  with gr.Row():
    # The Textbox keyword is `lines`; the original passed `line`.
    prompt_text = gr.Textbox(lines=1, label='Prompt')
    submit = gr.Button('Submit')
    # NOTE(review): a single-option Radio probably cannot be de-selected in
    # the UI, so the toggle may only ever fire once; a Checkbox might be the
    # better widget - confirm against the Gradio version in use.
    radio = gr.Radio(['Invert Mask'])
    radio.select(fn=invertmask)

  def generate_mask(image, evt: gr.SelectData):
    """Build a segmentation mask from all points clicked so far.

    Args:
      image: the current input image as passed in by Gradio.
      evt: select event; its `index` is appended to `selected_pixels`.

    Returns:
      A PIL image of the first predicted mask, inverted when `isInvert`
      is truthy.
    """
    selected_pixels.append(evt.index)
    # NOTE(review): set_image runs on every click, even for an unchanged image.
    predictor.set_image(image)
    input_points = np.array(selected_pixels)
    # Every click is given label 1 (a foreground point for SAM).
    input_label = np.ones(input_points.shape[0])
    mask, _, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_label,
        multimask_output=False
    )
    if isInvert:
      mask = np.logical_not(mask)
    # The code indexes the first (and, with multimask_output=False, only) mask.
    return Image.fromarray(mask[0, :, :])


  def inpaint(img, mask, prompt):
    """Inpaint the masked region of `img` with Stable Diffusion.

    Args:
      img: input image array from the 'Input' component.
      mask: mask array from the 'Mask' component.
      prompt: extra user text appended to the built-in base prompt.

    Returns:
      The first image produced by the inpainting pipeline.

    Raises:
      gr.Error: if the image or the mask has not been provided yet.
    """
    if img is None or mask is None:
      raise gr.Error("Upload an image and click on it to create a mask before submitting.")

    # Both inputs are resized to 512x512 before being handed to the pipeline.
    img = Image.fromarray(img).resize((512, 512))
    mask = Image.fromarray(mask).resize((512, 512))
    negative_prompts = """
    duplicate,low quality, lowest quality, bad shape,bad anatomy,
    bad proportions, lowres,error,watermark,username,artistname,
    signature,text,jpeg artifacts,blurry,more than one person,simple background
    """
    base_prompt = "Realistic professional Headshot of a man for a profile pic"
    # Join with a separator; the original glued the user text directly onto
    # the last word of the base prompt.
    extra = (prompt or "").strip()
    full_prompt = f"{base_prompt}, {extra}" if extra else base_prompt
    output = sd_pipeline(prompt=full_prompt,
                         image=img,
                         negative_prompt=negative_prompts,
                         mask_image=mask).images[0]

    return output


  # Event wiring.
  input_img.upload(fn=reset_points)
  input_img.select(generate_mask, [input_img], [mask_img])
  submit.click(inpaint,
               inputs=[input_img, mask_img, prompt_text],
               outputs=[output_img])

# Launch the Gradio app (debug mode on) only when run as a script.
if __name__ == "__main__":
  genaieg.launch(debug=True)