File size: 6,850 Bytes
3df1448 f8ebdc7 3df1448 f8ebdc7 3df1448 f8ebdc7 3df1448 f8ebdc7 3df1448 f8ebdc7 3df1448 f8ebdc7 3df1448 f8ebdc7 3df1448 f8ebdc7 3df1448 f8ebdc7 3df1448 f8ebdc7 3df1448 f8ebdc7 3df1448 f8ebdc7 3df1448 f8ebdc7 3df1448 f8ebdc7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import gradio as gr
from io import BytesIO
import requests
import PIL
from PIL import Image
import numpy as np
import os
import uuid
import torch
from torch import autocast
import cv2
from matplotlib import pyplot as plt
from torchvision import transforms
from diffusers import DiffusionPipeline
from diffusers.utils import torch_device
# Load the Paint-by-Example pipeline once, when the module is imported;
# predict() below reuses this single instance for every request.
# Weights are requested as float32 because this copy is meant to run on CPU
# (nothing in this file moves the pipeline to another device).
pipe = DiffusionPipeline.from_pretrained("Fantasy-Studio/Paint-by-Example", torch_dtype=torch.float32)
def _scaled_size(width, height):
    """Return ``(width, height)`` rescaled so the shorter side is 512 px.

    The longer side is scaled by the same factor and floored to a multiple
    of 8. Square inputs take the second branch and become ``(512, 512)``.
    """
    if width < height:
        factor = width / 512.0
        return 512, int((height / factor) / 8.0) * 8
    factor = height / 512.0
    return int((width / factor) / 8.0) * 8, 512


# Define function to predict
def predict(dict, reference, scale, seed, step):
    """Run the Paint-by-Example pipeline on the uploaded images.

    Args:
        dict: Value of the sketch-enabled source ``gr.Image`` (``type="pil"``):
            a mapping with a PIL ``"image"`` and a PIL ``"mask"``. The name
            shadows the builtin ``dict``; it is kept so existing callers that
            pass it by keyword keep working.
        reference: PIL image passed to the pipeline as ``example_image``.
        scale: Passed to the pipeline as ``guidance_scale``.
        seed: Seed for a fresh ``torch.Generator``; ``0`` means no generator
            is passed (unseeded run).
        step: Passed to the pipeline as ``num_inference_steps``.

    Returns:
        A 4-tuple: the generated PIL image, then three
        ``gr.update(visible=True)`` values for the share widgets.

    Raises:
        gr.Error: if the source image or the reference image is missing, so
            the UI shows a readable message instead of a stack trace.
    """
    if dict is None or dict.get("image") is None:
        raise gr.Error("Please upload a source image and paint a mask on it.")
    if reference is None:
        raise gr.Error("Please upload a reference image.")

    width, height = _scaled_size(*dict["image"].size)
    init_image = dict["image"].convert("RGB").resize((width, height))
    mask = dict["mask"].convert("RGB").resize((width, height))

    # Slider values may arrive as floats; torch.Generator.manual_seed and the
    # step count both need plain ints.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed) if seed != 0 else None
    output = pipe(
        image=init_image,
        mask_image=mask,
        example_image=reference,
        generator=generator,
        guidance_scale=scale,
        num_inference_steps=int(step),
    ).images[0]
    return output, gr.update(visible=True), gr.update(visible=True), gr.update(
        visible=True
    )
# Custom stylesheet, passed to gr.Blocks(css=css) when the UI is built below.
# Selectors match elem_id values set on components in the UI section
# (#image_upload, #share-btn-container, #share-btn) and the classes used in
# the footer HTML (.footer, .acknowledgments).
# NOTE(review): #mask_radio and #word_mask match no elem_id in this file, and
# @keyframes spin is not referenced here either -- presumably used by share
# button HTML defined elsewhere, or leftovers; confirm before removing.
# The string is sent to the browser verbatim, so edit it as CSS, not Python.
css = '''
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
#image_upload{min-height:400px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
#mask_radio .gr-form{background:transparent; border: none}
#word_mask{margin-top: .75em !important}
#word_mask textarea:disabled{opacity: 0.3}
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
#image_upload .touch-none{display: flex}
@keyframes spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
#share-btn-container {
display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
}
#share-btn {
all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
'''
# Helper for loading static HTML fragments (e.g. header.html) into the page.
def read_content(file_path: str) -> str:
    """Return the entire text of ``file_path``, decoded as UTF-8."""
    with open(file_path, encoding='utf-8') as handle:
        return handle.read()
# Example galleries shown under the editor: every file in each directory,
# as joined paths, in sorted order.
example = {}  # not used anywhere in this file; kept as a module-level name
ref_dir = 'examples/reference'
image_dir = 'examples/image'
ref_list = sorted(os.path.join(ref_dir, entry) for entry in os.listdir(ref_dir))
image_list = sorted(os.path.join(image_dir, entry) for entry in os.listdir(image_dir))
# Share-to-community assets. These names were used below but never defined or
# imported in this file, which raised NameError while building the UI.
# NOTE(review): presumably provided by a sibling `share_btn` module (as in the
# upstream Hugging Face Space) -- confirm. If it is absent, fall back to inert
# values so the demo still starts; the share button then does nothing.
try:
    from share_btn import community_icon_html, loading_icon_html, share_js
except ImportError:
    community_icon_html = ""
    loading_icon_html = ""
    share_js = None

# Create Gradio Blocks instance; `image_blocks` and `demo` are the same object.
image_blocks = gr.Blocks(css=css)
with image_blocks as demo:
    gr.HTML(read_content("header.html"))
    with gr.Group():
        with gr.Box():
            with gr.Row():
                with gr.Column():
                    # tool='sketch' makes predict() receive {"image": ..., "mask": ...}.
                    # NOTE(review): both inputs share elem_id "image_upload" (duplicate
                    # HTML id); the CSS relies on it for both, so it is left as-is.
                    image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Source Image")
                    reference = gr.Image(source='upload', elem_id="image_upload", type="pil", label="Reference Image")
                with gr.Column():
                    image_out = gr.Image(label="Output", elem_id="output-img").style(height=400)
                    guidance = gr.Slider(label="Guidance scale", value=5, maximum=15, interactive=True)
                    steps = gr.Slider(label="Steps", value=50, minimum=2, maximum=75, step=1, interactive=True)
                    seed = gr.Slider(0, 10000, label='Seed (0 = random)', value=0, step=1)
                    with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
                        btn = gr.Button("Paint!").style(
                            margin=False,
                            rounded=(False, True, True, False),
                            full_width=True,
                        )
                    with gr.Group(elem_id="share-btn-container"):
                        community_icon = gr.HTML(community_icon_html, visible=True)
                        loading_icon = gr.HTML(loading_icon_html, visible=True)
                        share_button = gr.Button("Share to community", elem_id="share-btn", visible=True)
            with gr.Row():
                with gr.Column():
                    gr.Examples(image_list, inputs=[image], label="Examples - Source Image", examples_per_page=12)
                with gr.Column():
                    gr.Examples(ref_list, inputs=[reference], label="Examples - Reference Image", examples_per_page=12)
    # predict() returns the image plus three visibility updates, matching these outputs.
    btn.click(fn=predict, inputs=[image, reference, guidance, seed, steps], outputs=[image_out, community_icon, loading_icon, share_button])
    # Client-side only handler; skipped when no share script is available.
    if share_js:
        share_button.click(None, [], [], _js=share_js)
    gr.HTML(
        """
<div class="footer">
<p>Model by <a href="https://huggingface.co/Fantasy-Studio/Paint-by-Example" style="text-decoration: underline;" target="_blank">Fantasy-Studio</a> - Gradio Demo by 🤗 Hugging Face
</p>
</div>
<div class="acknowledgments">
<p><h4>LICENSE</h4>
The model is licensed with a <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" style="text-decoration: underline;" target="_blank">CreativeML Open RAIL-M</a> license. The authors claim no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in this license. The license forbids you from sharing any content that violates any laws, produce any harm to a person, disseminate any personal information that would be meant for harm, spread misinformation and target vulnerable groups. For the full list of restrictions please <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" target="_blank" style="text-decoration: underline;">read the license</a></p>
</div>
"""
    )
# Launch the Gradio interface only when this file is executed as a script, so
# importing the module builds the UI without starting a blocking server.
if __name__ == "__main__":
    image_blocks.launch()
|