# Captured from a "Spaces" page whose status banner read: Runtime error
import gradio as gr
from train_esd import train_esd

# Model/config locations handed to train_esd by train().
# NOTE(review): these paths are relative, so they resolve against the
# process's working directory (presumably the Space's repository root) —
# confirm before launching from elsewhere.

# Stable Diffusion checkpoint (v1.4 full EMA, per the filename).
ckpt_path = "stable-diffusion/models/ldm/sd-v1-4-full-ema.ckpt"
# Inference config YAML for the checkpoint above.
config_path = "stable-diffusion/configs/stable-diffusion/v1-inference.yaml"
# Diffusers-format config JSON, passed through to train_esd.
diffusers_config_path = "stable-diffusion/config.json"
def train(prompt, train_method, neg_guidance, iterations, lr,
          start_guidance=3, devices=None):
    """Run ESD training with the values entered in the UI.

    Thin wrapper around ``train_esd`` that supplies the module-level
    checkpoint/config paths so the UI only has to provide the tunable
    hyperparameters.

    Args:
        prompt: Text of the concept to erase (from the "Prompt" textbox).
        train_method: Training method name; the UI offers
            'noxattn', 'selfattn', 'xattn' and 'full'.
        neg_guidance: Value of the "Negative Guidance" number field.
        iterations: Number of training iterations.
        lr: Learning rate.
        start_guidance: Passed to ``train_esd`` as its third positional
            argument. Defaults to ``3``, the value previously hard-coded
            here. NOTE(review): presumably the starting guidance scale —
            confirm against ``train_esd``'s signature.
        devices: Device list passed to ``train_esd``. ``None`` means
            ``['cuda']``, the value previously hard-coded here. A new list
            is built on every call rather than sharing a mutable default.

    Raises:
        gr.Error: If ``prompt`` is empty or whitespace-only, so the UI shows
            a message instead of starting a run with no concept to erase.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt describing the concept to erase.")
    if devices is None:
        devices = ['cuda']
    # NOTE(review): confirm how many devices train_esd expects; a single
    # device is passed by default.
    train_esd(prompt,
              train_method,
              start_guidance,
              neg_guidance,
              iterations,
              lr,
              config_path,
              ckpt_path,
              diffusers_config_path,
              devices)
# Build the training UI: one input per train() argument plus a Train button.
with gr.Blocks() as demo:
    prompt_input = gr.Text(
        placeholder="Enter prompt...",
        label="Prompt",
        info="Prompt corresponding to concept to erase"
    )
    train_method_input = gr.Dropdown(
        choices=['noxattn', 'selfattn', 'xattn', 'full'],
        value='xattn',
        label='Train Method',
        info='Method of training'
    )
    neg_guidance_input = gr.Number(
        value=1,
        label="Negative Guidance",
        info='Guidance of negative training used to train'
    )
    # precision=0 makes Gradio round to the nearest integer and return an int.
    iterations_input = gr.Number(
        value=1000,
        precision=0,
        label="Iterations",
        info='iterations used to train'
    )
    # Previously mislabelled "Iterations" (copy-paste from the field above).
    lr_input = gr.Number(
        value=1e-5,
        label="Learning Rate",
        info='Learning rate used to train'
    )
    train_button = gr.Button(
        value="Train",
    )
    # Component values are passed positionally to
    # train(prompt, train_method, neg_guidance, iterations, lr), so this
    # list order must match that signature.
    train_button.click(
        train,
        inputs=[
            prompt_input,
            train_method_input,
            neg_guidance_input,
            iterations_input,
            lr_input,
        ],
    )

# Training is long-running; route events through Gradio's queue so long runs
# are not cut off by the request timeout that applies without queuing.
demo.queue().launch()