import io

import numpy as np
import panel as pn
import PIL.Image
import torch

from diffusers import StableDiffusionInstructPix2PixPipeline

pn.extension(template="bootstrap")
pn.state.template.main_max_width = "690px"
pn.state.template.accent_base_color = "#F08080"
pn.state.template.header_background = "#F08080"

# Set up device 
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model
model_id = "timbrooks/instruct-pix2pix"

if device == "cuda":
    # Half precision on GPU to roughly halve memory use; float16 is poorly
    # supported for inference on CPU, so fall back to full precision there
    pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16
    )
else:
    pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id)
pipe = pipe.to(device)
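
# Optional tweak (a suggestion, not part of the original app): on GPUs with limited
# memory, attention slicing in diffusers lowers peak VRAM use at a small speed cost.
# pipe.enable_attention_slicing()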


def new_image(prompt, image, img_guidance, guidance, steps):
    """Run the InstructPix2Pix pipeline on `image` and return the edited image."""
    edit = pipe(
        prompt,
        image=image,
        image_guidance_scale=img_guidance,
        guidance_scale=guidance,
        num_inference_steps=steps,
    ).images[0]
    return edit
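
# Illustrative standalone usage (hypothetical, not part of the app): "photo.jpg" and
# "edited.jpg" are placeholder paths, and the values mirror the widget defaults below.
#
#   img = PIL.Image.open("photo.jpg").convert("RGB")
#   edited = new_image("turn it into a watercolor painting", img, 1.5, 7, 20)
#   edited.save("edited.jpg")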


# Panel widgets
file_input = pn.widgets.FileInput(width=600)
prompt = pn.widgets.TextInput(
    value="", placeholder="Enter image editing instruction here...", width=600
)
img_guidance = pn.widgets.DiscreteSlider(
    name="Image guidance scale", options=list(np.arange(1, 10.5, 0.5)), value=1.5
)
guidance = pn.widgets.DiscreteSlider(
    name="Guidance scale", options=list(np.arange(1, 10.5, 0.5)), value=7
)
steps = pn.widgets.IntSlider(name="Inference Steps", start=1, end=100, step=1, value=20)
run_button = pn.widgets.Button(name="Run!", width=600)


# define global variables to keep track of things
convos = []  # store all panel objects in a list
image = None
filename = None


def normalize_image(value, width):
    """
    Normalize the uploaded image to RGB and resize it to the given width,
    preserving the aspect ratio.
    """
    b = io.BytesIO(value)
    image = PIL.Image.open(b).convert("RGB")
    aspect = image.size[1] / image.size[0]
    height = int(aspect * width)
    return image.resize((width, height), PIL.Image.LANCZOS)


def get_conversations(_, img, img_guidance, guidance, steps, width=600):
    """
    Collect the conversation (prompts and generated images) into a Panel Column.
    """
    global image, filename
    prompt_text = prompt.value
    prompt.value = ""

    # if the filename changes, open the image again
    if filename != file_input.filename:
        filename = file_input.filename
        image = normalize_image(file_input.value, width)
        convos.clear()

    if prompt_text:
        # generate new image
        image = new_image(prompt_text, image, img_guidance, guidance, steps)
        convos.append(pn.Row("\U0001F60A", pn.pane.Markdown(prompt_text, width=600)))
        convos.append(pn.Row("\U0001F916", image))
    return pn.Column(*convos)


# bind widgets to functions
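# Note: the bound function re-runs whenever any of the bound widgets changes,
# so clicking the Run button (or moving a slider) re-renders the conversation.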
interactive_conversation = pn.bind(
    get_conversations, run_button, file_input, img_guidance, guidance, steps
)
interactive_upload = pn.bind(pn.panel, file_input, width=600)

# layout
pn.Column(
    pn.pane.Markdown("## \U0001F60A Upload an image file and start editing!"),
    pn.Column(file_input, pn.panel(interactive_upload)),
    pn.panel(interactive_conversation, loading_indicator=True),
    prompt,
    pn.Row(run_button),
    pn.Card(img_guidance, guidance, steps, width=600, header="Advanced settings"),
).servable(title="Stable Diffusion InstructPix2Pix Image Editing Chatbot")
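
# To run the app locally (standard Panel usage; "app.py" is a placeholder for
# whatever this file is named):
#
#   panel serve app.py --autoreload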