Spaces:
Runtime error
Runtime error
File size: 3,356 Bytes
ab316d0 6cd7a87 ab316d0 6cd7a87 dd213e4 6cd7a87 ab316d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import io
import hvplot.pandas
import numpy as np
import panel as pn
import param
import PIL
import requests
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
# Load Panel with the bootstrap template, then style that template:
# a narrow main column and the same colour for accent and header.
pn.extension(template="bootstrap")

_template = pn.state.template
_template.main_max_width = "690px"
_template.accent_base_color = "#F08080"
_template.header_background = "#F08080"
# Set up device: use the GPU when torch can see one, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model
model_id = "timbrooks/instruct-pix2pix"

# Half-precision weights are requested only on the GPU; on the CPU the
# pipeline is loaded with its default dtype.
_pretrained_kwargs = {"torch_dtype": torch.float16} if device == "cuda" else {}
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    model_id, **_pretrained_kwargs
).to(device)
def new_image(prompt, image, img_guidance, guidance, steps):
    """
    Run one edit through the module-level ``pipe`` and return the result.

    Args:
        prompt: Text instruction describing the edit.
        image: Picture to edit, passed to the pipeline as ``image``.
        img_guidance: Passed to the pipeline as ``image_guidance_scale``.
        guidance: Passed to the pipeline as ``guidance_scale``.
        steps: Passed to the pipeline as ``num_inference_steps``.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    settings = {
        "image": image,
        "image_guidance_scale": img_guidance,
        "guidance_scale": guidance,
        "num_inference_steps": steps,
    }
    return pipe(prompt, **settings).images[0]
# Panel widgets
file_input = pn.widgets.FileInput(width=600)
prompt = pn.widgets.TextInput(
    placeholder="Enter image editing instruction here...",
    value="",
    width=600,
)

# Both guidance sliders offer the same choices: 1.0 to 10.0 in steps of 0.5.
img_guidance = pn.widgets.DiscreteSlider(
    name="Image guidance scale",
    value=1.5,
    options=[*np.arange(1, 10.5, 0.5)],
)
guidance = pn.widgets.DiscreteSlider(
    name="Guidance scale",
    value=7,
    options=[*np.arange(1, 10.5, 0.5)],
)
steps = pn.widgets.IntSlider(
    name="Inference Steps",
    start=1,
    end=100,
    step=1,
    value=20,
)
run_button = pn.widgets.Button(name="Run!", width=600)

# Module-level state shared across calls of get_conversations
convos = []  # every Panel row shown in the conversation so far
image = filename = None  # current working image and the upload it came from
def normalize_image(value, width):
    """
    Decode image bytes, convert to RGB channels and scale to a fixed width.

    The height is derived from the source aspect ratio so the picture is not
    distorted.

    Args:
        value: Raw bytes of an encoded image file.
        width: Target width in pixels.

    Returns:
        An RGB PIL image that is ``width`` pixels wide.
    """
    # Import the submodule explicitly: a bare ``import PIL`` does not load
    # ``PIL.Image`` by itself.
    from PIL import Image

    source = Image.open(io.BytesIO(value)).convert("RGB")
    src_width, src_height = source.size
    # Clamp to one pixel so an extremely wide picture cannot truncate to a
    # zero height, which ``resize`` rejects.
    height = max(1, int(src_height / src_width * width))
    # ``Image.ANTIALIAS`` was removed in Pillow 10; ``LANCZOS`` is the same
    # filter under its current name.
    return source.resize((width, height), Image.LANCZOS)
def get_conversations(_, img, img_guidance, guidance, steps, width=600):
    """
    Update the conversation history and return it as a Panel column.

    A newly uploaded file replaces the working image and clears the history.
    A non-empty prompt runs one more edit on the working image. Rows are only
    added when there is something new to show, so re-evaluations with an
    empty prompt (for example after moving a slider) do not duplicate the
    last image.

    Args:
        _: Unused; present so the run button can trigger a re-evaluation.
        img: Unused; present so the file input can trigger a re-evaluation
            (the bytes are read from ``file_input`` directly).
        img_guidance: Image guidance scale forwarded to ``new_image``.
        guidance: Guidance scale forwarded to ``new_image``.
        steps: Number of inference steps forwarded to ``new_image``.
        width: Pixel width that uploaded images are normalized to.

    Returns:
        A ``pn.Column`` holding every row of the conversation so far.
    """
    global image, filename

    uploaded = False
    # if the filename changes, open the image again; skip when no bytes are
    # available so an empty upload cannot crash the decoder
    if filename != file_input.filename and file_input.value:
        filename = file_input.filename
        image = normalize_image(file_input.value, width)
        convos.clear()
        uploaded = True

    if image is None:
        # Nothing to edit yet: keep whatever the user typed instead of
        # wiping it or sending ``None`` to the pipeline.
        return pn.Column(*convos)

    prompt_text = prompt.value
    prompt.value = ""

    if prompt_text:
        # generate new image
        image = new_image(prompt_text, image, img_guidance, guidance, steps)
        convos.append(pn.Row("\U0001F60A", pn.pane.Markdown(prompt_text, width=600)))

    if uploaded or prompt_text:
        # Only show the image when it is new (fresh upload or fresh edit).
        convos.append(pn.Row("\U0001F916", image))

    return pn.Column(*convos)
# bind widgets to functions: the conversation is re-evaluated whenever the
# run button, the uploaded file or one of the sliders changes
interactive_conversation = pn.bind(
    get_conversations, run_button, file_input, img_guidance, guidance, steps
)
# preview of the uploaded file, rendered through pn.panel at a fixed width
interactive_upload = pn.bind(pn.panel, file_input, width=600)

# layout: heading, upload + preview, conversation, prompt, run button, settings
pn.Column(
    pn.pane.Markdown("## \U0001F60A Upload an image file and start editing!"),
    pn.Column(file_input, pn.panel(interactive_upload)),
    pn.panel(interactive_conversation, loading_indicator=True),
    prompt,
    pn.Row(run_button),
    pn.Card(img_guidance, guidance, steps, width=600, header="Advanced settings"),
).servable(title="Stable Diffusion InstructPix2pix Image Editing Chatbot")
|