JadenFK commited on
Commit
e066869
1 Parent(s): d9007d5

ckpt download as well as torch bug fix

Browse files
Files changed (2) hide show
  1. app.py +6 -3
  2. requirements.txt +2 -2
app.py CHANGED
@@ -64,11 +64,12 @@ class Demo:
64
  label="Learning Rate",
65
  info='Learning rate used to train'
66
  )
67
- self.progress_bar = gr.Text(interactive=False, label="Training Progress")
68
 
69
  self.train_button = gr.Button(
70
  value="Train",
71
  )
 
 
72
 
73
  with gr.Column(scale=2) as inference_column:
74
 
@@ -125,7 +126,7 @@ class Demo:
125
  self.iterations_input,
126
  self.lr_input
127
  ],
128
- outputs=[self.train_button, self.infr_button, self.progress_bar]
129
  )
130
 
131
  def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
@@ -201,6 +202,8 @@ class Demo:
201
  loss.backward()
202
  optimizer.step()
203
 
 
 
204
  self.finetuner = finetuner.eval().half()
205
 
206
  self.diffuser = self.diffuser.eval().half()
@@ -209,7 +212,7 @@ class Demo:
209
 
210
  self.training = False
211
 
212
- return [gr.update(interactive=True), gr.update(interactive=True), None]
213
 
214
 
215
  def inference(self, prompt, seed, pbar = gr.Progress(track_tqdm=True)):
64
  label="Learning Rate",
65
  info='Learning rate used to train'
66
  )
 
67
 
68
  self.train_button = gr.Button(
69
  value="Train",
70
  )
71
+
72
+ self.download = gr.Files()
73
 
74
  with gr.Column(scale=2) as inference_column:
75
 
126
  self.iterations_input,
127
  self.lr_input
128
  ],
129
+ outputs=[self.train_button, self.infr_button, self.download]
130
  )
131
 
132
  def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
202
  loss.backward()
203
  optimizer.step()
204
 
205
+ torch.save(finetuner.state_dict(), 'ft.ckpt')
206
+
207
  self.finetuner = finetuner.eval().half()
208
 
209
  self.diffuser = self.diffuser.eval().half()
212
 
213
  self.training = False
214
 
215
+ return [gr.update(interactive=True), gr.update(interactive=True), 'ft.ckpt']
216
 
217
 
218
  def inference(self, prompt, seed, pbar = gr.Progress(track_tqdm=True)):
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
  gradio
2
- torch
3
- torchvision
4
  diffusers
5
  transformers
6
  accelerate
1
  gradio
2
+ torch==1.13.1 --index-url https://download.pytorch.org/whl/cu118
3
+ torchvision==0.14.1 --index-url https://download.pytorch.org/whl/cu118
4
  diffusers
5
  transformers
6
  accelerate