Spaces:
Runtime error
Runtime error
File size: 1,657 Bytes
640a27b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import gradio as gr
from train_esd import train_esd
# Model artifacts that train() hands to train_esd. All three are relative
# paths rooted at the "stable-diffusion" directory.
config_path: str = "stable-diffusion/configs/stable-diffusion/v1-inference.yaml"  # model config (YAML)
ckpt_path: str = "stable-diffusion/models/ldm/sd-v1-4-full-ema.ckpt"  # sd-v1-4 full-EMA checkpoint file
diffusers_config_path: str = "stable-diffusion/config.json"  # diffusers-side config (JSON)
def train(prompt, train_method, neg_guidance, iterations, lr,
          start_guidance=3, devices=None):
    """Run ``train_esd`` for one concept using the module-level model paths.

    This is the callback wired to the "Train" button. Gradio passes the five
    UI inputs positionally, so the first five parameters must stay in this
    order. The two trailing keyword parameters were previously hard-coded
    and keep their old values as defaults.

    Args:
        prompt: Text prompt corresponding to the concept to erase.
        train_method: One of the dropdown choices
            ('noxattn', 'selfattn', 'xattn', 'full'); forwarded unchanged.
        neg_guidance: Negative-guidance value forwarded to ``train_esd``.
        iterations: Number of training iterations forwarded to ``train_esd``.
        lr: Learning rate forwarded to ``train_esd``.
        start_guidance: Value passed in ``train_esd``'s third positional
            slot (formerly the literal ``3``). Presumably the starting /
            positive guidance scale -- confirm against ``train_esd``'s
            signature.
        devices: Device list forwarded to ``train_esd``; ``None`` means
            ``['cuda']`` (the former hard-coded value).

    Returns:
        None. The result of ``train_esd`` (if any) is not returned, and the
        button has no outputs wired up.
    """
    if devices is None:
        # Build a fresh list per call rather than sharing a mutable default.
        # NOTE(review): confirm train_esd works with a single-entry device
        # list; some ESD training scripts index a second device.
        devices = ['cuda']
    train_esd(
        prompt,
        train_method,
        start_guidance,
        neg_guidance,
        iterations,
        lr,
        config_path,
        ckpt_path,
        diffusers_config_path,
        devices,
    )
# Build the Gradio UI: five inputs feeding train() through a single button.
with gr.Blocks() as demo:
    prompt_input = gr.Text(
        placeholder="Enter prompt...",
        label="Prompt",
        info="Prompt corresponding to concept to erase"
    )
    # The selected string is forwarded unchanged to train_esd as train_method.
    train_method_input = gr.Dropdown(
        choices=['noxattn', 'selfattn', 'xattn', 'full'],
        value='xattn',
        label='Train Method',
        info='Method of training'
    )
    neg_guidance_input = gr.Number(
        value=1,
        label="Negative Guidance",
        info='Guidance of negative training used to train'
    )
    # precision=0 makes Gradio round the value to a whole number.
    iterations_input = gr.Number(
        value=1000,
        precision=0,
        label="Iterations",
        info='iterations used to train'
    )
    # Label was previously "Iterations" (copy-paste from the field above).
    lr_input = gr.Number(
        value=1e-5,
        label="Learning Rate",
        info='Learning rate used to train'
    )
    train_button = gr.Button(
        value="Train",
    )
    # Inputs are delivered to train() positionally: keep this list in the
    # same order as train()'s first five parameters. No outputs are wired
    # up, so the UI shows no result when training finishes.
    train_button.click(
        train,
        inputs=[
            prompt_input,
            train_method_input,
            neg_guidance_input,
            iterations_input,
            lr_input,
        ],
    )

# Launch after the Blocks context has closed so the layout is finalized.
demo.launch()