JadenFK's picture
Init for demo
640a27b
raw
history blame
No virus
1.66 kB
import gradio as gr
from train_esd import train_esd
# Model files handed to train_esd by train(). The paths are relative, so they
# presumably resolve against the working directory the app is started from —
# confirm against how train_esd opens them.
ckpt_path, config_path, diffusers_config_path = (
    # Checkpoint weights (filename suggests SD v1.4 full-EMA; verify).
    "stable-diffusion/models/ldm/sd-v1-4-full-ema.ckpt",
    # Inference config YAML for the checkpoint above.
    "stable-diffusion/configs/stable-diffusion/v1-inference.yaml",
    # JSON config, passed as train_esd's diffusers_config_path argument.
    "stable-diffusion/config.json",
)
def train(prompt, train_method, neg_guidance, iterations, lr):
    """Run ESD training for one concept using the values entered in the UI.

    Forwards the UI inputs to ``train_esd`` together with the module-level
    model paths (``config_path``, ``ckpt_path``, ``diffusers_config_path``).

    Args:
        prompt: Text of the concept to erase (from the "Prompt" box).
        train_method: Expected to be one of the dropdown choices
            ('noxattn', 'selfattn', 'xattn', 'full'); not validated here.
        neg_guidance: Negative-guidance value from the UI.
        iterations: Number of training iterations from the UI.
        lr: Learning rate from the UI.

    Returns:
        None. Whatever ``train_esd`` returns is discarded, so nothing is
        sent back to the interface.
    """
    # NOTE(review): this 3 sits between train_method and neg_guidance in
    # train_esd's positional arguments; presumably a start/positive guidance
    # scale — confirm against train_esd's signature.
    fixed_guidance = 3
    # NOTE(review): only one device is supplied; confirm train_esd does not
    # index more than one entry of this list.
    device_list = ['cuda']
    # Order matters: train_esd takes these paths positionally as
    # (config_path, ckpt_path, diffusers_config_path).
    model_files = (config_path, ckpt_path, diffusers_config_path)
    train_esd(prompt, train_method, fixed_guidance, neg_guidance,
              iterations, lr, *model_files, device_list)
# Gradio UI: one input component per train() argument, plus a button that
# runs train() with the current values.
with gr.Blocks() as demo:
    # Concept to erase; forwarded to train() as `prompt`.
    prompt_input = gr.Text(
        placeholder="Enter prompt...",
        label="Prompt",
        info="Prompt corresponding to concept to erase",
    )
    # Which subset of the model to train; forwarded as `train_method`.
    train_method_input = gr.Dropdown(
        choices=['noxattn', 'selfattn', 'xattn', 'full'],
        value='xattn',
        label='Train Method',
        info='Method of training',
    )
    # Forwarded as `neg_guidance`.
    neg_guidance_input = gr.Number(
        value=1,
        label="Negative Guidance",
        info='Guidance of negative training used to train',
    )
    # precision=0: per the Gradio Number docs, the value is rounded to an int,
    # which suits an iteration count. Forwarded as `iterations`.
    iterations_input = gr.Number(
        value=1000,
        precision=0,
        label="Iterations",
        info='iterations used to train',
    )
    # Forwarded as `lr`. Labeled "Learning Rate" so it is not confused with
    # the iterations field above.
    lr_input = gr.Number(
        value=1e-5,
        label="Learning Rate",
        info='Learning rate used to train',
    )
    train_button = gr.Button(
        value="Train",
    )
    # The inputs list is passed to train() positionally, so its order must
    # match train(prompt, train_method, neg_guidance, iterations, lr).
    # train() returns nothing, so no outputs are wired.
    train_button.click(
        train,
        inputs=[
            prompt_input,
            train_method_input,
            neg_guidance_input,
            iterations_input,
            lr_input,
        ],
    )

# Start the Gradio app; this runs whenever the module is executed or imported.
demo.launch()