JadenFK's picture
Layout changes
0166058
raw
history blame
No virus
6.53 kB
import sys
sys.path.insert(0,'stable_diffusion')
import gradio as gr
from train_esd import train_esd
from convertModels import convert_ldm_unet_checkpoint, create_unet_diffusers_config
from omegaconf import OmegaConf
from StableDiffuser import StableDiffuser
from diffusers import UNet2DConditionModel
# Locations of the Stable Diffusion assets used by training and conversion.
# All three are relative paths, so they resolve against the working directory
# the app is started from.
ckpt_path, config_path, diffusers_config_path = (
    "stable_diffusion/models/ldm/sd-v1-4-full-ema.ckpt",  # LDM checkpoint
    "stable_diffusion/configs/stable-diffusion/v1-inference.yaml",  # LDM YAML config
    "stable_diffusion/config.json",  # diffusers config handed to train_esd
)
class Demo:
    """Gradio app: erase a concept from Stable Diffusion (ESD) and compare outputs.

    Instantiating the class builds the UI and calls ``launch()`` on it, so
    construction starts the web app.
    """

    def __init__(self) -> None:
        import threading  # only this class needs it; keeps the top-level imports untouched

        # Public flags kept for compatibility with any external readers; the
        # locks below are what actually enforce one run of each at a time.
        self.training = False
        self.generating = False
        # The queue runs up to 10 events concurrently, so a plain
        # check-then-set on a bool can race; a non-blocking acquire is atomic.
        self._train_lock = threading.Lock()
        self._generate_lock = threading.Lock()
        # Filled in by init_inference() once a training run has finished.
        self.diffuser = None
        self.model_edited_sd = None
        self.model_orig_sd = None

        with gr.Blocks() as demo:
            self.layout()

        demo.queue(concurrency_count=10).launch()

    def disable(self):
        """Return updates that make the Train and Generate buttons non-interactive."""
        return [gr.update(interactive=False), gr.update(interactive=False)]

    def layout(self):
        """Build the training (left) and inference (right) columns and wire events.

        Must be called inside a ``gr.Blocks`` context.
        """
        with gr.Row():
            with gr.Column(scale=1):
                self.prompt_input = gr.Text(
                    placeholder="Enter prompt...",
                    label="Prompt to Erase",
                    info="Prompt corresponding to concept to erase",
                )
                self.train_method_input = gr.Dropdown(
                    choices=['noxattn', 'selfattn', 'xattn', 'full'],
                    value='xattn',
                    label='Train Method',
                    info='Method of training',
                )
                self.neg_guidance_input = gr.Number(
                    value=1,
                    label="Negative Guidance",
                    info='Guidance of negative training used to train',
                )
                self.iterations_input = gr.Number(
                    value=1000,
                    precision=0,
                    label="Iterations",
                    info='iterations used to train',
                )
                self.lr_input = gr.Number(
                    value=1e-5,
                    label="Learning Rate",
                    info='Learning rate used to train',
                )
                self.progress_bar = gr.Text(interactive=False, label="Training Progress")
                self.train_button = gr.Button(
                    value="Train",
                )
            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Column(scale=4):
                        self.prompt_input_infr = gr.Text(
                            placeholder="Enter prompt...",
                            label="Prompt",
                            info="Prompt to generate",
                        )
                    with gr.Column(scale=1):
                        # precision=0 makes Gradio deliver the seed as an int
                        # (like iterations_input) rather than a float.
                        self.seed_infr = gr.Number(
                            label="Seed",
                            value=42,
                            precision=0,
                        )
                with gr.Row():
                    self.image_new = gr.Image(
                        label="New Image",
                        interactive=False,
                    )
                    self.image_orig = gr.Image(
                        label="Orig Image",
                        interactive=False,
                    )
                with gr.Row():
                    # Starts disabled: there is nothing to sample until a model is trained.
                    self.infr_button = gr.Button(
                        value="Generate",
                        interactive=False,
                    )

        self.infr_button.click(
            self.inference,
            inputs=[self.prompt_input_infr, self.seed_infr],
            outputs=[self.image_new, self.image_orig],
        )
        # Two handlers on the same click: grey out both buttons right away,
        # then run the (long) training, which re-enables them when done.
        self.train_button.click(
            self.disable,
            outputs=[self.train_button, self.infr_button],
        )
        self.train_button.click(
            self.train,
            inputs=[
                self.prompt_input,
                self.train_method_input,
                self.neg_guidance_input,
                self.iterations_input,
                self.lr_input,
            ],
            outputs=[self.train_button, self.infr_button, self.progress_bar],
        )

    def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar=gr.Progress(track_tqdm=True)):
        """Run ESD training for ``prompt`` and prepare both UNets for inference.

        Args:
            prompt: Text of the concept to erase.
            train_method: One of the dropdown choices ('noxattn', 'selfattn', 'xattn', 'full').
            neg_guidance: Negative guidance value forwarded to ``train_esd``.
            iterations: Number of training iterations.
            lr: Learning rate.
            pbar: Gradio progress tracker; with ``track_tqdm=True`` Gradio mirrors
                tqdm bars into the UI. Not used directly.

        Returns:
            Updates for ``[train_button, infr_button, progress_bar]``.
        """
        if not self._train_lock.acquire(blocking=False):
            # Another session is training. The `disable` handler has already
            # greyed out this session's buttons; returning None would leave
            # them disabled, so re-enable them and explain why nothing ran.
            return [
                gr.update(interactive=True),
                gr.update(interactive=self.diffuser is not None),
                "Training already in progress, please try again later.",
            ]
        self.training = True
        try:
            model_orig, model_edited = train_esd(
                prompt,
                train_method,
                3,  # NOTE(review): positional arg, presumably the start guidance — confirm against train_esd
                neg_guidance,
                iterations,
                lr,
                config_path,
                ckpt_path,
                diffusers_config_path,
                ['cuda', 'cuda'],  # presumably one device per model (orig, edited) — confirm
            )
            # Convert the LDM-format UNet weights into diffusers format.
            original_config = OmegaConf.load(config_path)
            # Force a 4-channel UNet input before deriving the diffusers config.
            original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = 4
            unet_config = create_unet_diffusers_config(original_config, image_size=512)
            model_edited_sd = convert_ldm_unet_checkpoint(model_edited.state_dict(), unet_config)
            model_orig_sd = convert_ldm_unet_checkpoint(model_orig.state_dict(), unet_config)
            self.init_inference(model_edited_sd, model_orig_sd, unet_config)
        finally:
            # Always release, otherwise one failed run would block training forever.
            self.training = False
            self._train_lock.release()
        return [gr.update(interactive=True), gr.update(interactive=True), None]

    def init_inference(self, model_edited_sd, model_orig_sd, unet_config):
        """Store both converted state dicts and build a CUDA diffuser to load them into.

        Args:
            model_edited_sd: diffusers-format UNet state dict after erasure.
            model_orig_sd: diffusers-format UNet state dict before erasure.
            unet_config: kwargs for ``UNet2DConditionModel``.
        """
        self.model_edited_sd = model_edited_sd
        self.model_orig_sd = model_orig_sd
        self.diffuser = StableDiffuser(42)  # 42 is presumably the default seed — confirm
        # Fresh UNet shell; inference() loads one of the state dicts into it per call.
        self.diffuser.unet = UNet2DConditionModel(**unet_config)
        self.diffuser.to('cuda')

    @staticmethod
    def _generate_with(diffuser, state_dict, prompt):
        """Load ``state_dict`` into ``diffuser.unet`` and return the first generated image."""
        diffuser.unet.load_state_dict(state_dict)
        images = diffuser(
            prompt,
            n_steps=50,
            reseed=True,
        )
        # Assumes a nested result (batch -> images); take the first image — TODO confirm.
        return images[0][0]

    def inference(self, prompt, seed, pbar=gr.Progress(track_tqdm=True)):
        """Generate ``prompt`` with the original and the edited UNet using the same seed.

        Args:
            prompt: Text prompt to generate.
            seed: Seed assigned to the diffuser before both generations.
            pbar: Gradio progress tracker (tqdm mirroring); not used directly.

        Returns:
            ``(edited_image, orig_image)`` for ``[image_new, image_orig]``, or
            ``[None, None]`` if no model is trained yet or another generation
            is running.
        """
        if self.diffuser is None:
            return [None, None]
        if not self._generate_lock.acquire(blocking=False):
            return [None, None]
        self.generating = True
        try:
            # Snapshot shared state so a training run that reassigns these
            # attributes mid-call does not affect the two generations.
            diffuser = self.diffuser
            orig_sd = self.model_orig_sd
            edited_sd = self.model_edited_sd
            diffuser._seed = seed
            orig_image = self._generate_with(diffuser, orig_sd, prompt)
            edited_image = self._generate_with(diffuser, edited_sd, prompt)
        finally:
            # Always release, otherwise one failed generation would block all later ones.
            self.generating = False
            self._generate_lock.release()
        return edited_image, orig_image
# Demo.__init__ builds the Gradio UI and calls launch(), so this line starts the app.
demo = Demo()