Spaces:
Sleeping
Sleeping
# NOTE: disabled monkeypatch that replaces the diffusers `deprecate` helpers
# with no-ops. Kept as a comment for reference only.
# from diffusers import utils
# from diffusers.utils import deprecation_utils
# from diffusers.models import cross_attention
# utils.deprecate = lambda *arg, **kwargs: None
# deprecation_utils.deprecate = lambda *arg, **kwargs: None
# cross_attention.deprecate = lambda *arg, **kwargs: None

import os
import random
import sys

# NOTE: disabled bootstrap that put the parent directory of this file on
# sys.path and made it the working directory.
# MAIN_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
# sys.path.insert(0, MAIN_DIR)
# os.chdir(MAIN_DIR)

import gradio as gr
import numpy as np
import torch
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.pipelines import DiffusionPipeline
from diffusers.schedulers import DPMSolverMultistepScheduler

from annotator.canny import CannyDetector
from annotator.util import resize_image, HWC3
# from models import ControlLoRA, ControlLoRACrossAttnProcessor

# Canny edge detector instance used to build the conditioning image.
apply_canny = CannyDetector()
# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE: disabled ControlLoRA-based pipeline setup. It is not executed; it is
# kept as a comment for reference. It relies on `ControlLoRA` and
# `ControlLoRACrossAttnProcessor` from the (also disabled) `models` import.
#
# pipeline = DiffusionPipeline.from_pretrained(
#     'IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1', safety_checker=None
# )
# pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
# pipeline = pipeline.to(device)
# unet: UNet2DConditionModel = pipeline.unet
#
# # ckpt_path = "ckpts/sd-diffusiondb-canny-model-control-lora-zh"
# ckpt_path = "svjack/canny-control-lora-zh"
# control_lora = ControlLoRA.from_pretrained(ckpt_path)
# control_lora = control_lora.to(device)
#
# # load control lora attention processors
# lora_attn_procs = {}
# lora_layers_list = list([list(layer_list) for layer_list in control_lora.lora_layers])
# n_ch = len(unet.config.block_out_channels)
# control_ids = [i for i in range(n_ch)]
# for name in pipeline.unet.attn_processors.keys():
#     cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
#     if name.startswith("mid_block"):
#         control_id = control_ids[-1]
#     elif name.startswith("up_blocks"):
#         block_id = int(name[len("up_blocks.")])
#         control_id = list(reversed(control_ids))[block_id]
#     elif name.startswith("down_blocks"):
#         block_id = int(name[len("down_blocks.")])
#         control_id = control_ids[block_id]
#
#     lora_layers = lora_layers_list[control_id]
#     if len(lora_layers) != 0:
#         lora_layer: ControlLoRACrossAttnProcessor = lora_layers.pop(0)
#         lora_attn_procs[name] = lora_layer
#
# unet.set_attn_processor(lora_attn_procs)
from diffusers import ( | |
AutoencoderKL, | |
ControlNetModel, | |
DDPMScheduler, | |
StableDiffusionControlNetPipeline, | |
UNet2DConditionModel, | |
UniPCMultistepScheduler, | |
) | |
import torch | |
from diffusers.utils import load_image | |
# Hub identifiers of the ControlNet weights and of the base Stable Diffusion
# checkpoint they are combined with.
controlnet_model_name_or_path = "svjack/ControlNet-Canny-Zh"
base_model_path = "IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1"

controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path)

# Half precision is left disabled (torch_dtype=torch.float16 is not passed).
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
)

# Swap in the UniPC multistep scheduler, built from the loaded scheduler's
# configuration.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# Model CPU offload (pipe.enable_model_cpu_offload()) is left disabled; the
# pipeline is moved to the GPU only when `device` is "cuda".
if device == "cuda":
    pipe = pipe.to("cuda")

# Drop the safety checker from the loaded pipeline.
pipe.safety_checker = None
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, sample_steps, scale, seed, eta, low_threshold, high_threshold):
    """Generate images from a prompt, conditioned on Canny edges of an image.

    The input image is normalised with ``HWC3`` and ``resize_image``, a Canny
    edge map is computed from it, and that map is passed as the control image
    to the module-level ``pipe`` once per requested sample.

    Args:
        input_image: Image array coming from the UI image input.
        prompt: Main text prompt.
        a_prompt: Additional prompt text, appended to ``prompt`` after ``', '``.
        n_prompt: Negative prompt.
        num_samples: Number of images to generate (one pipeline call each).
        image_resolution: Target resolution handed to ``resize_image``.
        sample_steps: Number of inference steps.
        scale: Guidance scale.
        seed: Seed for the torch generator; ``-1`` selects a random seed in
            ``[0, 65535]``.
        eta: Forwarded to the pipeline as ``eta``.
        low_threshold: Lower threshold passed to the Canny detector.
        high_threshold: Upper threshold passed to the Canny detector.

    Returns:
        A list whose first element is the inverted edge map
        (``255 - detected_map``), followed by one array per generated image.

    Raises:
        ValueError: If ``input_image`` is ``None``.
    """
    if input_image is None:
        # Fail with a clear message instead of an opaque error inside HWC3.
        raise ValueError("An input image is required.")

    from PIL import Image

    # range() and torch.Generator.manual_seed() require real ints, and UI
    # widgets are not guaranteed to deliver them.
    num_samples = int(num_samples)
    sample_steps = int(sample_steps)
    seed = int(seed)
    if seed == -1:
        seed = random.randint(0, 65535)

    with torch.no_grad():
        img = resize_image(HWC3(input_image), image_resolution)
        height, width = img.shape[:2]

        detected_map = HWC3(apply_canny(img, low_threshold, high_threshold))
        control_image = Image.fromarray(detected_map)

        # The generator is created once and reused for every sample.
        generator = torch.Generator(device=device).manual_seed(seed)
        full_prompt = prompt + ', ' + a_prompt

        images = []
        for _ in range(num_samples):
            image = pipe(
                full_prompt,
                negative_prompt=n_prompt,
                num_inference_steps=sample_steps,
                guidance_scale=scale,
                eta=eta,
                image=control_image,
                generator=generator,
                height=height,
                width=width,
            ).images[0]
            images.append(np.asarray(image))

    # The edge map is inverted for display in the returned list.
    return [255 - detected_map] + images
# Build the Gradio UI. Components are created in layout order, so the order
# of the statements below determines what the page looks like.
block = gr.Blocks()
block.queue()

with block:
    # Page title.
    with gr.Row():
        gr.Markdown("## Control Stable Diffusion with Canny Edge Maps")
        #gr.Markdown("This _example_ was **drive** from <br/><b><h4>[https://github.com/svjack/ControlLoRA-Chinese](https://github.com/svjack/ControlLoRA-Chinese)</h4></b>\n")

    with gr.Row():
        # Left column: inputs and generation settings.
        with gr.Column():
            input_image = gr.Image(
                source='upload',
                type="numpy",
                value="hate_dog.png",
            )
            prompt = gr.Textbox(label="Prompt", value="可爱的狗宝宝")
            run_button = gr.Button(label="Run")

            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(
                    label="Images",
                    minimum=1,
                    maximum=12,
                    value=1,
                    step=1,
                )
                image_resolution = gr.Slider(
                    label="Image Resolution",
                    minimum=256,
                    maximum=768,
                    value=512,
                    step=256,
                )
                low_threshold = gr.Slider(
                    label="Canny low threshold",
                    minimum=1,
                    maximum=255,
                    value=100,
                    step=1,
                )
                high_threshold = gr.Slider(
                    label="Canny high threshold",
                    minimum=1,
                    maximum=255,
                    value=200,
                    step=1,
                )
                sample_steps = gr.Slider(
                    label="Steps",
                    minimum=1,
                    maximum=100,
                    value=20,
                    step=1,
                )
                scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=0.1,
                    maximum=30.0,
                    value=9.0,
                    step=0.1,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=-1,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
                eta = gr.Number(label="eta", value=0.0)
                a_prompt = gr.Textbox(label="Added Prompt", value='')
                n_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value='低质量,模糊,混乱',
                )

        # Right column: gallery showing the edge map and generated images.
        with gr.Column():
            result_gallery = gr.Gallery(
                label='Output',
                show_label=False,
                elem_id="gallery",
            ).style(grid=2, height='auto')

    # Inputs are listed in the positional order of process()'s parameters.
    ips = [
        input_image,
        prompt,
        a_prompt,
        n_prompt,
        num_samples,
        image_resolution,
        sample_steps,
        scale,
        seed,
        eta,
        low_threshold,
        high_threshold,
    ]
    run_button.click(
        fn=process,
        inputs=ips,
        outputs=[result_gallery],
        show_progress=True,
    )

# Serve on all network interfaces.
block.launch(server_name='0.0.0.0')
#### block.launch(server_name='172.16.202.228', share=True)