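# Spaces: Runtime error
# (The line above appears to be a status banner captured from the hosting page together with the source; it is not part of the program.)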
from pathlib import Path | |
import gradio as gr | |
import torch | |
from finetuning import FineTunedModel | |
from StableDiffuser import StableDiffuser | |
from tqdm import tqdm | |
class Demo:
    """Gradio app that fine-tunes Stable Diffusion to erase a concept and compares outputs.

    The left column collects training settings and runs ``train``; the right
    column runs ``inference`` and shows the image produced with the fine-tuned
    modules next to the image produced by the unmodified diffuser.

    Constructing an instance loads the diffuser onto CUDA, builds the UI and
    launches the Gradio server.
    """

    # Maps the "Train Method" dropdown choice to the module-name pattern that
    # is handed to FineTunedModel.
    _TRAIN_METHOD_MODULES = {
        'ESD-x': ".*attn2$",
        'ESD-self': ".*attn1$",
    }

    def __init__(self) -> None:
        # Busy flags: each handler refuses to start while its flag is set.
        self.training = False
        self.generating = False
        # Number of scheduler steps used for the partial denoising during training.
        self.nsteps = 50

        self.diffuser = StableDiffuser(scheduler='DDIM', seed=42).to('cuda')
        # Set by train(); stays None until a fine-tuning run has completed.
        self.finetuner = None

        with gr.Blocks() as demo:
            self.layout()
            demo.queue(concurrency_count=2).launch()

    def disable(self):
        """Return updates that make the Train and Generate buttons non-interactive."""
        return [gr.update(interactive=False), gr.update(interactive=False)]

    def layout(self):
        """Create all UI components and wire the button click handlers."""
        with gr.Row():
            self.explain = gr.HTML(interactive=False,
                value="<p>This page demonstrates Erasing Concepts in Stable Diffusion (Ganikota, Materzynska, Fiotto-Kaufman and Bau; paper and code linked from https://erasing.baulab.info/). <br> Use it in two steps <br> 1. First, on the left fine-tune your own custom model by naming the concept that you want to erase. For example, you can try erasing all cars from a model by entering the prompt corresponding to the concept to erase as 'car'. This can take awhile. For example, with the default settings, this can take about an hour. <br> 2. Second, on the right once you have your model fine-tuned, you can try running it in inference. <br>If you want to run it yourself, then you can create your own instance. Configuration, code, and details are at https://github.com/xxxx/xxxx/xxx</p>")

        with gr.Row():

            # Left column: training controls.
            with gr.Column(scale=1):
                self.prompt_input = gr.Text(
                    placeholder="Enter prompt...",
                    label="Prompt to Erase",
                    info="Prompt corresponding to concept to erase"
                )
                self.train_method_input = gr.Dropdown(
                    choices=['ESD-x', 'ESD-self'],
                    value='ESD-x',
                    label='Train Method',
                    info='Method of training'
                )
                self.neg_guidance_input = gr.Number(
                    value=1,
                    label="Negative Guidance",
                    info='Guidance of negative training used to train'
                )
                self.iterations_input = gr.Number(
                    value=150,
                    precision=0,
                    label="Iterations",
                    info='iterations used to train'
                )
                self.lr_input = gr.Number(
                    value=1e-5,
                    label="Learning Rate",
                    info='Learning rate used to train'
                )
                self.train_button = gr.Button(
                    value="Train",
                )
                self.download = gr.Files()

            # Right column: inference controls and the two result images.
            with gr.Column(scale=2):

                with gr.Row():
                    with gr.Column(scale=5):
                        self.prompt_input_infr = gr.Text(
                            placeholder="Enter prompt...",
                            label="Prompt",
                            info="Prompt to generate"
                        )
                    with gr.Column(scale=1):
                        self.seed_infr = gr.Number(
                            label="Seed",
                            value=42
                        )

                with gr.Row():
                    self.image_new = gr.Image(
                        label="New Image",
                        interactive=False
                    )
                    self.image_orig = gr.Image(
                        label="Orig Image",
                        interactive=False
                    )

                with gr.Row():
                    # Starts disabled; train() re-enables it once a model exists.
                    self.infr_button = gr.Button(
                        value="Generate",
                        interactive=False
                    )

        self.infr_button.click(self.inference, inputs = [
            self.prompt_input_infr,
            self.seed_infr
            ],
            outputs=[
                self.image_new,
                self.image_orig
            ]
        )
        # Two handlers on the same click: the first greys out both buttons,
        # the second runs the (long) training job and re-enables them.
        self.train_button.click(self.disable,
            outputs=[self.train_button, self.infr_button]
        )
        self.train_button.click(self.train, inputs = [
            self.prompt_input,
            self.train_method_input,
            self.neg_guidance_input,
            self.iterations_input,
            self.lr_input
            ],
            outputs=[self.train_button, self.infr_button, self.download]
        )

    def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
        """Fine-tune the selected modules so that ``prompt`` is erased, then save a checkpoint.

        Args:
            prompt: Text describing the concept to erase.
            train_method: 'ESD-x' or 'ESD-self'; selects which modules are fine-tuned.
            neg_guidance: Scale applied to (positive - neutral) when building the target.
            iterations: Number of optimisation steps.
            lr: Learning rate for Adam.
            pbar: Gradio progress tracker; with ``track_tqdm=True`` it mirrors
                the tqdm bar created below, so it is not used directly.

        Returns:
            ``[None, None, None]`` if a training run is already in progress,
            otherwise updates re-enabling both buttons plus the checkpoint path
            ``'ft.ckpt'``.

        Raises:
            ValueError: If ``train_method`` is not a known method.
        """
        if self.training:
            return [None, None, None]
        self.training = True

        try:
            # Validate before touching any model state.
            if train_method not in self._TRAIN_METHOD_MODULES:
                raise ValueError(f"Unknown train method: {train_method!r}")
            modules = self._TRAIN_METHOD_MODULES[train_method]

            # Drop the previous fine-tuned model (keeping the attribute defined)
            # so its memory can be reclaimed before training starts.
            self.finetuner = None
            torch.cuda.empty_cache()

            self.diffuser = self.diffuser.train().float()

            finetuner = FineTunedModel(self.diffuser, modules)

            optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
            criteria = torch.nn.MSELoss()

            progress_bar = tqdm(range(iterations))

            with torch.no_grad():
                neutral_text_embeddings = self.diffuser.get_text_embeddings([''],n_imgs=1)
                positive_text_embeddings = self.diffuser.get_text_embeddings([prompt],n_imgs=1)

            for _ in progress_bar:

                with torch.no_grad():

                    self.diffuser.set_scheduler_timesteps(self.nsteps)

                    optimizer.zero_grad()

                    # Random step index within the nsteps-step schedule.
                    iteration = torch.randint(1, self.nsteps - 1, (1,)).item()

                    latents = self.diffuser.get_initial_latents(1, 512, 1)

                    # Denoise up to the sampled step inside the finetuner context.
                    with finetuner:
                        latents_steps, _ = self.diffuser.diffusion(
                            latents,
                            positive_text_embeddings,
                            start_iteration=0,
                            end_iteration=iteration,
                            guidance_scale=3,
                            show_progress=False
                        )

                    self.diffuser.set_scheduler_timesteps(1000)

                    # Rescale the step index from the nsteps schedule to the 1000-step one.
                    iteration = int(iteration / self.nsteps * 1000)

                    # Targets are computed outside the finetuner context and without gradients.
                    positive_latents = self.diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=3)
                    neutral_latents = self.diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=3)

                # This prediction is the one that is optimised, so it must track gradients.
                with finetuner:
                    negative_latents = self.diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=3)

                positive_latents.requires_grad = False
                neutral_latents.requires_grad = False

                # loss = criteria(e_n, e_0) works the best try 5000 epochs
                loss = criteria(negative_latents, neutral_latents - (neg_guidance*(positive_latents - neutral_latents)))
                loss.backward()
                optimizer.step()

            torch.save(finetuner.state_dict(), 'ft.ckpt')

            self.finetuner = finetuner.eval().half()
            self.diffuser = self.diffuser.eval().half()
        finally:
            # Always release the busy flag, otherwise one failed run would make
            # every later call return early forever.
            torch.cuda.empty_cache()
            self.training = False

        return [gr.update(interactive=True), gr.update(interactive=True), 'ft.ckpt']

    def inference(self, prompt, seed, pbar = gr.Progress(track_tqdm=True)):
        """Generate one image with and one without the fine-tuned modules.

        Args:
            prompt: Text prompt to generate from.
            seed: Seed stored on the diffuser before generation.
            pbar: Gradio progress tracker (tracks tqdm bars; not used directly).

        Returns:
            ``[None, None]`` if a generation is already in progress, otherwise
            ``(edited_image, orig_image)``.

        Raises:
            gr.Error: If no fine-tuned model is available yet.
        """
        if self.generating:
            return [None, None]
        self.generating = True

        try:
            if self.finetuner is None:
                raise gr.Error("No fine-tuned model available; run training first.")

            # gr.Number yields a float unless a precision is set; seeding
            # presumably expects an integer, so cast here (None is passed through).
            self.diffuser._seed = int(seed) if seed is not None else seed

            images = self.diffuser(
                prompt,
                n_steps=50,
                reseed=True
            )
            orig_image = images[0][0]

            torch.cuda.empty_cache()

            # Same prompt and seed, this time inside the finetuner context.
            with self.finetuner:
                images = self.diffuser(
                    prompt,
                    n_steps=50,
                    reseed=True
                )
            edited_image = images[0][0]
        finally:
            # Release the busy flag even when generation fails.
            self.generating = False
            torch.cuda.empty_cache()

        return edited_image, orig_image
# Constructing Demo builds the UI and launches the Gradio server
# (Demo.__init__ calls launch()), so this line starts the app.
demo = Demo()