# mist-v2 / mist-webui.py
# AeroXi's picture
# Upload folder using huggingface_hub
# ece766c
import gradio as gr
from attacks.mist import update_args_with_config, main
'''
TODO:
- SDXL
- model changing
'''
def process_image(
    eps, device, mode, resize, data_path, output_path, model_path, class_path, prompt,
    class_prompt, max_train_steps, max_f_train_steps, max_adv_train_steps, lora_lr, pgd_lr,
    rank, prior_loss_weight, fused_weight, constraint_mode, lpips_bound, lpips_weight,
):
    """Run Mist with the settings collected from the web UI.

    All parameters are packed, in declaration order, into a single tuple that is
    handed to ``update_args_with_config`` (with ``None`` as the initial args
    object). The args object it returns is then passed to ``main``.

    Returns:
        None.
    """
    # The tuple order must stay identical to the parameter order above.
    settings = (
        eps, device, mode, resize, data_path, output_path, model_path, class_path, prompt,
        class_prompt, max_train_steps, max_f_train_steps, max_adv_train_steps, lora_lr, pgd_lr,
        rank, prior_loss_weight, fused_weight, constraint_mode, lpips_bound, lpips_weight,
    )
    run_args = update_args_with_config(None, settings)
    main(run_args)
if __name__ == "__main__":
    # Build the Gradio page: a logo on top, then one column of controls whose
    # values are passed positionally to process_image when "Mist" is clicked.
    with gr.Blocks() as demo:
        with gr.Column():
            gr.Image("MIST_logo.png", show_label=False)
            with gr.Row():
                with gr.Column():
                    # --- Basic settings ---
                    eps = gr.Slider(0, 32, step=1, value=10, label='Strength',
                                    info="Larger strength results in stronger but more visible defense.")
                    device = gr.Radio(["cpu", "gpu"], value="gpu", label="Device",
                                      info="If you do not have good GPUs on your PC, choose 'cpu'.")
                    # Disabled control (not part of `inputs` below):
                    # precision = gr.Radio(["float16", "bfloat16"], value="bfloat16", label="Precision",
                    #                      info="Precision used in computing")
                    resize = gr.Checkbox(value=True, label="Resizing the output image to the original resolution")
                    mode = gr.Radio(["Mode 1", "Mode 2", "Mode 3"], value="Mode 1", label="Mode",
                                    info="All three modes work, with different visualization.")
                    # Disabled control (not part of `inputs` below); see the SDXL TODO at the top of the file:
                    # model_type = gr.Radio(["Stable Diffusion", "SDXL"], value="Stable Diffusion", label="Target Model",
                    #                       info="Model used by imaginary copyright infringers")
                    data_path = gr.Textbox(label="Data Path", lines=1, placeholder="Path to your images")
                    output_path = gr.Textbox(label="Output Path", lines=1, placeholder="Path to store the outputs")
                    model_path = gr.Textbox(label="Target Model Path", lines=1, placeholder="Path to the target model")
                    class_path = gr.Textbox(label="Path to place contrast images ", lines=1,
                                            placeholder="Path to the contrast images")
                    prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Describe your images")

                    # --- Advanced settings, collapsed by default ---
                    with gr.Accordion("Professional Setups", open=False):
                        class_prompt = gr.Textbox(label="Class prompt", lines=1,
                                                  placeholder="Prompt for contrast images.")
                        max_train_steps = gr.Slider(1, 20, step=1, value=5, label='Epochs',
                                                    info="Training epochs of Mist-V2")
                        max_f_train_steps = gr.Slider(0, 30, step=1, value=10, label='LoRA Steps',
                                                      info="Training steps of LoRA in one epoch")
                        max_adv_train_steps = gr.Slider(0, 100, step=5, value=30, label='Attacking Steps',
                                                        info="Training steps of attacking in one epoch")
                        lora_lr = gr.Number(minimum=0.0, maximum=1.0, label="The learning rate of LoRA", value=0.0001)
                        pgd_lr = gr.Number(minimum=0.0, maximum=1.0, label="The learning rate of PGD", value=0.005)
                        rank = gr.Slider(4, 32, step=4, value=4, label='LoRA Ranks',
                                         info="Ranks of LoRA (Bigger ranks need better GPUs)")
                        prior_loss_weight = gr.Number(minimum=0.0, maximum=1.0, label="The weight of prior loss", value=0.1)
                        fused_weight = gr.Number(minimum=0.0, maximum=1.0, label="The weight of vae loss", value=0.00001)
                        constraint_mode = gr.Radio(["Epsilon", "LPIPS"], value="Epsilon", label="Constraint Mode",
                                                   info="The mode to constraint the watermark")
                        lpips_bound = gr.Number(minimum=0.0, maximum=0.2, label="The LPIPS bound", value=0.1)
                        lpips_weight = gr.Number(minimum=0.0, maximum=1.0, label="The weight of LPIPS constraint", value=0.5)

                    # Order matters: it must match process_image's positional parameters.
                    inputs = [
                        eps, device, mode, resize, data_path, output_path, model_path, class_path, prompt,
                        class_prompt, max_train_steps, max_f_train_steps, max_adv_train_steps, lora_lr, pgd_lr,
                        rank, prior_loss_weight, fused_weight, constraint_mode, lpips_bound, lpips_weight,
                    ]
                    image_button = gr.Button("Mist")

            # No output components: process_image returns nothing to display.
            image_button.click(process_image, inputs=inputs)

    # NOTE(review): share=True asks Gradio for a public share link in addition to the
    # local server — confirm that exposing this UI publicly is intended.
    demo.queue().launch(share=True)