JadenFK commited on
Commit
f0cf9b0
1 Parent(s): 843b14b
Files changed (3) hide show
  1. app.py +85 -78
  2. requirements.txt +2 -0
  3. train_esd.py +1 -1
app.py CHANGED
@@ -7,104 +7,110 @@ from omegaconf import OmegaConf
7
  from StableDiffuser import StableDiffuser
8
  from diffusers import UNet2DConditionModel
9
 
10
- ckpt_path = "stable-diffusion/models/ldm/sd-v1-4-full-ema.ckpt"
11
- config_path = "stable-diffusion/configs/stable-diffusion/v1-inference.yaml"
12
- diffusers_config_path = "stable-diffusion/config.json"
13
 
14
 
15
  class Demo:
16
 
17
  def __init__(self) -> None:
18
- demo = self.layout()
19
- demo.launch()
20
 
 
 
 
 
 
 
21
 
22
  def layout(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- with gr.Row():
27
- with gr.Column() as training_column:
28
- self.prompt_input = gr.Text(
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  placeholder="Enter prompt...",
30
  label="Prompt",
31
  info="Prompt corresponding to concept to erase"
32
  )
33
- self.train_method_input = gr.Dropdown(
34
- choices=['noxattn', 'selfattn', 'xattn', 'full'],
35
- value='xattn',
36
- label='Train Method',
37
- info='Method of training'
38
- )
39
 
40
- self.neg_guidance_input = gr.Number(
41
- value=1,
42
- label="Negative Guidance",
43
- info='Guidance of negative training used to train'
44
- )
45
 
46
- self.iterations_input = gr.Number(
47
- value=1000,
48
- precision=0,
49
- label="Iterations",
50
- info='iterations used to train'
51
  )
52
-
53
- self.lr_input = gr.Number(
54
- value=1e-5,
55
- label="Learning Rate",
56
- info='Learning rate used to train'
57
  )
58
 
59
- self.train_button = gr.Button(
60
- value="Train",
 
 
 
61
  )
62
- self.train_button.click(self.train, inputs = [
63
- self.prompt_input,
64
- self.train_method_input,
65
- self.neg_guidance_input,
66
- self.iterations_input,
67
- self.lr_input
68
  ]
69
  )
70
- with gr.Column() as inference_column:
71
-
72
- with gr.Row():
73
-
74
- self.prompt_input_infr = gr.Text(
75
- placeholder="Enter prompt...",
76
- label="Prompt",
77
- info="Prompt corresponding to concept to erase"
78
- )
79
-
80
- with gr.Row():
81
-
82
- self.image_new = gr.Image(
83
- label="New Image",
84
- interactive=False
85
- )
86
- self.image_orig = gr.Image(
87
- label="Orig Image",
88
- interactive=False
89
- )
90
-
91
- with gr.Row():
92
-
93
- self.infr_button = gr.Button(
94
- value="Generate",
95
- )
96
- self.infr_button.click(self.inference, inputs = [
97
- self.prompt_input_infr,
98
- ],
99
- outputs=[
100
- self.image_new,
101
- self.image_orig
102
- ]
103
- )
104
- return demo
105
-
106
 
107
- def train(self, prompt, train_method, neg_guidance, iterations, lr):
108
 
109
  model_orig, model_edited = train_esd(prompt,
110
  train_method,
@@ -115,8 +121,7 @@ class Demo:
115
  config_path,
116
  ckpt_path,
117
  diffusers_config_path,
118
- ['cuda', 'cuda'],
119
- gr.Progress()
120
  )
121
 
122
  original_config = OmegaConf.load(config_path)
@@ -127,6 +132,8 @@ class Demo:
127
 
128
  self.init_inference(model_edited_sd, model_orig_sd, unet_config)
129
 
 
 
130
  def init_inference(self, model_edited_sd, model_orig_sd, unet_config):
131
 
132
  self.model_edited_sd = model_edited_sd
@@ -163,5 +170,5 @@ class Demo:
163
  return edited_image, orig_image
164
 
165
 
166
-
167
 
 
7
  from StableDiffuser import StableDiffuser
8
  from diffusers import UNet2DConditionModel
9
 
10
+ ckpt_path = "stable_diffusion/models/ldm/sd-v1-4-full-ema.ckpt"
11
+ config_path = "stable_diffusion/configs/stable-diffusion/v1-inference.yaml"
12
+ diffusers_config_path = "stable_diffusion/config.json"
13
 
14
 
15
  class Demo:
16
 
17
  def __init__(self) -> None:
 
 
18
 
19
+ with gr.Blocks() as demo:
20
+ self.layout()
21
+ demo.queue(concurrency_count=10).launch()
22
+
23
+ def disable(self):
24
+ return [gr.update(interactive=False), gr.update(interactive=False)]
25
 
26
  def layout(self):
27
+ with gr.Row():
28
+ with gr.Column() as training_column:
29
+ self.prompt_input = gr.Text(
30
+ placeholder="Enter prompt...",
31
+ label="Prompt",
32
+ info="Prompt corresponding to concept to erase"
33
+ )
34
+ self.train_method_input = gr.Dropdown(
35
+ choices=['noxattn', 'selfattn', 'xattn', 'full'],
36
+ value='xattn',
37
+ label='Train Method',
38
+ info='Method of training'
39
+ )
40
 
41
+ self.neg_guidance_input = gr.Number(
42
+ value=1,
43
+ label="Negative Guidance",
44
+ info='Guidance of negative training used to train'
45
+ )
46
+
47
+ self.iterations_input = gr.Number(
48
+ value=1000,
49
+ precision=0,
50
+ label="Iterations",
51
+ info='iterations used to train'
52
+ )
53
 
54
+ self.lr_input = gr.Number(
55
+ value=1e-5,
56
+ label="Learning Rate",
57
+ info='Learning rate used to train'
58
+ )
59
+ self.train_button = gr.Button(
60
+ value="Train",
61
+
62
+ )
63
+
64
+
65
+ with gr.Column() as inference_column:
66
+
67
+ with gr.Row():
68
+
69
+ self.prompt_input_infr = gr.Text(
70
  placeholder="Enter prompt...",
71
  label="Prompt",
72
  info="Prompt corresponding to concept to erase"
73
  )
 
 
 
 
 
 
74
 
75
+ with gr.Row():
 
 
 
 
76
 
77
+ self.image_new = gr.Image(
78
+ label="New Image",
79
+ interactive=False
 
 
80
  )
81
+ self.image_orig = gr.Image(
82
+ label="Orig Image",
83
+ interactive=False
 
 
84
  )
85
 
86
+ with gr.Row():
87
+
88
+ self.infr_button = gr.Button(
89
+ value="Generate",
90
+ interactive=False
91
  )
92
+ self.infr_button.click(self.inference, inputs = [
93
+ self.prompt_input_infr,
94
+ ],
95
+ outputs=[
96
+ self.image_new,
97
+ self.image_orig
98
  ]
99
  )
100
+ self.train_button.click(self.disable,
101
+ outputs=[self.train_button, self.infr_button]
102
+ )
103
+ self.train_button.click(self.train, inputs = [
104
+ self.prompt_input,
105
+ self.train_method_input,
106
+ self.neg_guidance_input,
107
+ self.iterations_input,
108
+ self.lr_input
109
+ ],
110
+ outputs=[self.train_button, self.infr_button]
111
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
+ def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
114
 
115
  model_orig, model_edited = train_esd(prompt,
116
  train_method,
 
121
  config_path,
122
  ckpt_path,
123
  diffusers_config_path,
124
+ ['cuda', 'cuda']
 
125
  )
126
 
127
  original_config = OmegaConf.load(config_path)
 
132
 
133
  self.init_inference(model_edited_sd, model_orig_sd, unet_config)
134
 
135
+ return [gr.update(interactive=True), gr.update(interactive=True)]
136
+
137
  def init_inference(self, model_edited_sd, model_orig_sd, unet_config):
138
 
139
  self.model_edited_sd = model_edited_sd
 
170
  return edited_image, orig_image
171
 
172
 
173
+ demo = Demo()
174
 
requirements.txt CHANGED
@@ -7,5 +7,7 @@ transformers
7
  pytorch_lightning==1.6.5
8
  taming-transformers
9
  kornia
 
 
10
  git+https://github.com/openai/CLIP.git@main#egg=clip
11
  git+https://github.com/davidbau/baukit.git
 
7
  pytorch_lightning==1.6.5
8
  taming-transformers
9
  kornia
10
+ scipy
11
+ accelerate
12
  git+https://github.com/openai/CLIP.git@main#egg=clip
13
  git+https://github.com/davidbau/baukit.git
train_esd.py CHANGED
@@ -102,7 +102,7 @@ def get_models(config_path, ckpt_path, devices):
102
 
103
  return model_orig, sampler_orig, model, sampler
104
 
105
- def train_esd(prompt, train_method, start_guidance, negative_guidance, iterations, lr, config_path, ckpt_path, diffusers_config_path, devices, progress_bar, seperator=None, image_size=512, ddim_steps=50):
106
  '''
107
  Function to train diffusion models to erase concepts from model weights
108
 
 
102
 
103
  return model_orig, sampler_orig, model, sampler
104
 
105
+ def train_esd(prompt, train_method, start_guidance, negative_guidance, iterations, lr, config_path, ckpt_path, diffusers_config_path, devices, seperator=None, image_size=512, ddim_steps=50):
106
  '''
107
  Function to train diffusion models to erase concepts from model weights
108