Spaces:
Runtime error
Runtime error
import sys | |
sys.path.insert(0,'stable_diffusion') | |
import gradio as gr | |
from train_esd import train_esd | |
from convertModels import convert_ldm_unet_checkpoint, create_unet_diffusers_config | |
from omegaconf import OmegaConf | |
from StableDiffuser import StableDiffuser | |
from diffusers import UNet2DConditionModel | |
# File locations inside the bundled ``stable_diffusion`` checkout.
#   ckpt_path             - checkpoint passed to train_esd
#   config_path           - LDM YAML config passed to train_esd; also reloaded
#                           with OmegaConf to derive the diffusers UNet config
#   diffusers_config_path - JSON config passed to train_esd
ckpt_path, config_path, diffusers_config_path = (
    "stable_diffusion/models/ldm/sd-v1-4-full-ema.ckpt",
    "stable_diffusion/configs/stable-diffusion/v1-inference.yaml",
    "stable_diffusion/config.json",
)
class Demo:
    """Gradio front-end for erasing a concept from Stable Diffusion via ``train_esd``.

    The left column collects the prompt to erase plus training hyper-parameters
    and runs ``train_esd``. The right column renders a prompt with both the
    original and the edited UNet weights so they can be compared side by side.

    Instantiating the class builds the Blocks UI and launches the server.
    """

    def __init__(self) -> None:
        # Busy flags shared by every browser session. The queue below is
        # configured with concurrency_count=10, so handlers can overlap; these
        # flags stop a second training run / generation from starting while
        # one is already in flight.
        self.training = False
        self.generating = False

        with gr.Blocks() as demo:
            self.layout()

        demo.queue(concurrency_count=10).launch()

    def disable(self):
        """Return updates that make the train and generate buttons non-interactive."""
        return [gr.update(interactive=False), gr.update(interactive=False)]

    def layout(self):
        """Create all components and wire the click handlers.

        Must be called inside an active ``gr.Blocks`` context (see ``__init__``).
        """
        with gr.Row():
            # Training controls.
            with gr.Column(scale=1):
                self.prompt_input = gr.Text(
                    placeholder="Enter prompt...",
                    label="Prompt to Erase",
                    info="Prompt corresponding to concept to erase",
                )
                self.train_method_input = gr.Dropdown(
                    choices=['noxattn', 'selfattn', 'xattn', 'full'],
                    value='xattn',
                    label='Train Method',
                    info='Method of training',
                )
                self.neg_guidance_input = gr.Number(
                    value=1,
                    label="Negative Guidance",
                    info='Guidance of negative training used to train',
                )
                self.iterations_input = gr.Number(
                    value=1000,
                    precision=0,
                    label="Iterations",
                    info='iterations used to train',
                )
                self.lr_input = gr.Number(
                    value=1e-5,
                    label="Learning Rate",
                    info='Learning rate used to train',
                )
                self.progress_bar = gr.Text(interactive=False, label="Training Progress")
                self.train_button = gr.Button(value="Train")

            # Inference controls: original vs. edited model side by side.
            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Column(scale=4):
                        self.prompt_input_infr = gr.Text(
                            placeholder="Enter prompt...",
                            label="Prompt",
                            info="Prompt to generate",
                        )
                    with gr.Column(scale=1):
                        # precision=0 so the handler receives an int seed
                        # rather than a float.
                        self.seed_infr = gr.Number(
                            label="Seed",
                            value=42,
                            precision=0,
                        )
                with gr.Row():
                    self.image_new = gr.Image(label="New Image", interactive=False)
                    self.image_orig = gr.Image(label="Orig Image", interactive=False)
                with gr.Row():
                    # Starts disabled: there is nothing to generate with until
                    # train() has produced a model and re-enabled this button.
                    self.infr_button = gr.Button(value="Generate", interactive=False)

        self.infr_button.click(
            self.inference,
            inputs=[self.prompt_input_infr, self.seed_infr],
            outputs=[self.image_new, self.image_orig],
        )
        # Two listeners on the same click: one greys out both buttons, the
        # other runs training and re-enables them when it returns.
        self.train_button.click(
            self.disable,
            outputs=[self.train_button, self.infr_button],
        )
        self.train_button.click(
            self.train,
            inputs=[
                self.prompt_input,
                self.train_method_input,
                self.neg_guidance_input,
                self.iterations_input,
                self.lr_input,
            ],
            outputs=[self.train_button, self.infr_button, self.progress_bar],
        )

    def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar=gr.Progress(track_tqdm=True)):
        """Run ESD training for ``prompt`` and prepare both UNets for inference.

        Args:
            prompt: text of the concept to erase.
            train_method: one of the dropdown choices
                ('noxattn', 'selfattn', 'xattn', 'full'), forwarded to train_esd.
            neg_guidance: negative-guidance value forwarded to train_esd.
            iterations: number of training iterations (int, via precision=0).
            lr: learning rate forwarded to train_esd.
            pbar: Gradio progress tracker; ``track_tqdm=True`` mirrors tqdm bars.

        Returns:
            Three updates for ``[train_button, infr_button, progress_bar]``.

        Exceptions from training propagate to Gradio. The busy flag is still
        cleared, so a later request can retry.
        """
        if self.training:
            # Another run (possibly from another session) is in progress. The
            # sibling ``disable`` listener has already greyed out this
            # session's buttons, so hand the train button back. Leave Generate
            # usable only if a model from an earlier run already exists.
            return [
                gr.update(interactive=True),
                gr.update(interactive=hasattr(self, "diffuser")),
                "Another training run is in progress; please try again when it finishes.",
            ]

        self.training = True
        try:
            model_orig, model_edited = train_esd(
                prompt,
                train_method,
                3,  # NOTE(review): positional arg defined by train_esd (presumably start guidance) — confirm
                neg_guidance,
                iterations,
                lr,
                config_path,
                ckpt_path,
                diffusers_config_path,
                ['cuda', 'cuda'],
            )

            # Derive a diffusers UNet config from the LDM YAML config, forcing
            # 4 input channels, then convert both LDM UNet state dicts to the
            # diffusers layout.
            original_config = OmegaConf.load(config_path)
            original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = 4
            unet_config = create_unet_diffusers_config(original_config, image_size=512)
            model_edited_sd = convert_ldm_unet_checkpoint(model_edited.state_dict(), unet_config)
            model_orig_sd = convert_ldm_unet_checkpoint(model_orig.state_dict(), unet_config)

            self.init_inference(model_edited_sd, model_orig_sd, unet_config)
        finally:
            # Always release the flag. Previously a failed run left it set and
            # every later training request silently did nothing.
            self.training = False

        return [gr.update(interactive=True), gr.update(interactive=True), None]

    def init_inference(self, model_edited_sd, model_orig_sd, unet_config):
        """Store both converted UNet state dicts and build a fresh CUDA pipeline.

        Args:
            model_edited_sd: diffusers-layout state dict of the edited UNet.
            model_orig_sd: diffusers-layout state dict of the original UNet.
            unet_config: kwargs for ``UNet2DConditionModel``.
        """
        # Kept as state dicts; inference() loads one or the other into the
        # single pipeline UNet before each generation.
        self.model_edited_sd = model_edited_sd
        self.model_orig_sd = model_orig_sd
        self.diffuser = StableDiffuser(42)
        # Swap in a UNet built from the converted config so the state dicts
        # above can be loaded into it.
        self.diffuser.unet = UNet2DConditionModel(**unet_config)
        self.diffuser.to('cuda')
        # Kept for direct callers; train() also clears the flag in ``finally``.
        self.training = False

    def inference(self, prompt, seed, pbar=gr.Progress(track_tqdm=True)):
        """Render ``prompt`` with the original and the edited UNet weights.

        Args:
            prompt: text prompt to generate.
            seed: seed stored on the pipeline before both generations; both
                calls pass ``reseed=True``, presumably so each starts from the
                same noise.
            pbar: Gradio progress tracker; ``track_tqdm=True`` mirrors tqdm bars.

        Returns:
            ``(edited_image, orig_image)`` for ``[image_new, image_orig]``, or
            two no-op updates if another generation is already running.
        """
        if self.generating:
            # Keep whatever is currently displayed instead of blanking it.
            return [gr.update(), gr.update()]

        self.generating = True
        try:
            # gr.Number may deliver a float. NOTE(review): assumes
            # StableDiffuser seeds its RNG from ``_seed`` and expects an int
            # — confirm.
            self.diffuser._seed = int(seed) if seed is not None else seed

            self.diffuser.unet.load_state_dict(self.model_orig_sd)
            images = self.diffuser(prompt, n_steps=50, reseed=True)
            orig_image = images[0][0]  # presumably first image of the first batch

            self.diffuser.unet.load_state_dict(self.model_edited_sd)
            images = self.diffuser(prompt, n_steps=50, reseed=True)
            edited_image = images[0][0]
        finally:
            # Always release the flag. Previously one failed generation made
            # every later request return nothing.
            self.generating = False

        return edited_image, orig_image


# Instantiating Demo builds the UI and launches the Gradio server (see __init__).
demo = Demo()