import gradio as gr

from pipeline_rf import RectifiedFlowPipeline

import torch

from diffusers import StableDiffusionXLImg2ImgPipeline
import time
import copy
import numpy as np

def merge_dW_to_unet(pipe, dW_dict, alpha=1.0):
    # Add the weight difference dW, scaled by alpha, on top of the pipeline's UNet weights.
    _tmp_sd = pipe.unet.state_dict()
    for key in dW_dict.keys():
        _tmp_sd[key] += dW_dict[key] * alpha
    pipe.unet.load_state_dict(_tmp_sd, strict=False)
    return pipe

def get_dW_and_merge(pipe_rf, lora_path='Lykon/dreamshaper-7', save_dW=False, base_sd='runwayml/stable-diffusion-v1-5', alpha=1.0):
    # Get the weights of the base SD model.
    from diffusers import DiffusionPipeline
    _pipe = DiffusionPipeline.from_pretrained(
        base_sd,
        torch_dtype=torch.float16,
        safety_checker=None,
    )
    sd_state_dict = _pipe.unet.state_dict()

    # Get the weights of the customized SD model (e.g., aniverse downloaded from civitai.com).
    _pipe = DiffusionPipeline.from_pretrained(
        lora_path,
        torch_dtype=torch.float16,
        safety_checker=None,
    )
    lora_unet_checkpoint = _pipe.unet.state_dict()

    # Compute the weight difference dW between the customized and base UNets.
    dW_dict = {}
    for key in lora_unet_checkpoint.keys():
        dW_dict[key] = lora_unet_checkpoint[key] - sd_state_dict[key]

    # Optionally save the dW dict for later reuse.
    if save_dW:
        save_name = lora_path.split('/')[-1] + '_dW.pt'
        torch.save(dW_dict, save_name)

    # Merge dW into the rectified-flow pipeline and reuse the customized VAE and text encoder.
    pipe_rf = merge_dW_to_unet(pipe_rf, dW_dict=dW_dict, alpha=alpha)
    pipe_rf.vae = _pipe.vae
    pipe_rf.text_encoder = _pipe.text_encoder

    return dW_dict
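
# A minimal sketch (not part of the original demo) of reusing a dW file previously
# saved via get_dW_and_merge(save_dW=True). It assumes the '<model>_dW.pt' naming
# convention above and skips re-downloading both pipelines; note it does not
# restore the customized VAE and text encoder.
def load_saved_dW_and_merge(pipe, dW_path, alpha=1.0):
    dW_dict = torch.load(dW_path)
    return merge_dW_to_unet(pipe, dW_dict=dW_dict, alpha=alpha)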



# SDXL refiner, used for the optional second-stage refinement.
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
pipe = pipe.to("cuda")

# One-step InstaFlow-0.9B pipeline, with the dreamshaper-7 weight difference merged in.
insta_pipe = RectifiedFlowPipeline.from_pretrained("XCLiu/instaflow_0_9B_from_sd_1_5", torch_dtype=torch.float16)
dW_dict = get_dW_and_merge(insta_pipe, lora_path="Lykon/dreamshaper-7", save_dW=False, alpha=1.0)
insta_pipe.to("cuda")

# Most recent generated image (uint8 HWC numpy array), shared between the two callbacks.
img = None

@torch.no_grad()
def set_new_latent_and_generate_new_image(seed, prompt, randomize_seed, num_inference_steps=1, guidance_scale=0.0):
    print('Generate with input seed')
    global img
    if randomize_seed:
        seed = np.random.randint(0, 2**32)
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    guidance_scale = float(guidance_scale)
    print(seed, num_inference_steps, guidance_scale)

    # One-step generation with InstaFlow, timed end-to-end.
    t_s = time.time()
    generator = torch.manual_seed(seed)
    images = insta_pipe(prompt=prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator).images
    inf_time = time.time() - t_s

    # Cache the generated image for the refinement step.
    img = copy.copy(np.array(images[0]))

    return images[0], inf_time, seed

@torch.no_grad()
def refine_image_512(prompt):
    print('Refine with SDXL-Refiner (512)')
    global img

    t_s = time.time()
    # Normalize the cached uint8 image to float32 in [0, 1] (HWC), the range the refiner expects.
    input_img = np.array(img, dtype=np.float32) / 255.0
    new_image = pipe(prompt, image=input_img).images[0]
    print('time consumption:', time.time() - t_s)

    # Cache the refined image as uint8 so a further refinement pass re-normalizes correctly.
    img = copy.copy(np.array(new_image))

    return new_image


with gr.Blocks() as gradio_gui:
    gr.Markdown(
    """
    # InstaFlow! One-Step Stable Diffusion with Rectified Flow [[paper]](https://arxiv.org/abs/2309.06380)
    ## This demo runs one-step InstaFlow-0.9B with the weight difference of [dreamshaper-7](https://huggingface.co/Lykon/dreamshaper-7) merged in to improve image quality, and measures the inference time.
    """)

    with gr.Row():
        with gr.Column(scale=0.4):
            with gr.Group():
                gr.Markdown("Generation from InstaFlow-0.9B")
                im = gr.Image()
        
        with gr.Column(scale=0.4):
            inference_time_output = gr.Textbox(value='0.0', label='Inference Time with One-Step InstaFlow (Seconds)')
            seed_input = gr.Textbox(value='101098274', label="Random Seed")
            randomize_seed = gr.Checkbox(label="Randomly Sample a Seed", value=True)
            prompt_input = gr.Textbox(value='A high-resolution photograph of a waterfall in autumn; muted tone', label="Prompt")

            new_image_button = gr.Button(value="One-Step Generation with InstaFlow and the Random Seed")
            new_image_button.click(set_new_latent_and_generate_new_image, inputs=[seed_input, prompt_input, randomize_seed], outputs=[im, inference_time_output, seed_input])

            refine_button_512 = gr.Button(value="Refine One-Step Generation with SDXL Refiner (Resolution: 512)")
            refine_button_512.click(refine_image_512, inputs=[prompt_input], outputs=[im])


gradio_gui.launch()