RohitGandikota commited on
Commit
958100a
1 Parent(s): 239a0d3

changed ft.pt to prompt-based

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -262,7 +262,8 @@ class Demo:
262
  loss.backward()
263
  optimizer.step()
264
 
265
- torch.save(finetuner.state_dict(), f'{prompt.lower().replace(' ', '')}.pt')
 
266
 
267
  self.finetuner = finetuner.eval().half()
268
 
@@ -272,9 +273,9 @@ class Demo:
272
 
273
  self.training = False
274
 
275
- model_map['Custom'] = f'{prompt.lower().replace(' ', '')}.pt'
276
 
277
- return [gr.update(interactive=True), gr.update(interactive=True), f'{prompt.lower().replace(' ', '')}.pt', gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
278
 
279
 
280
  def inference(self, prompt, seed, pbar = gr.Progress(track_tqdm=True)):
 
262
  loss.backward()
263
  optimizer.step()
264
 
265
+ ft_path = f"{prompt.lower().replace(' ', '')}.pt"
266
+ torch.save(finetuner.state_dict(), ft_path)
267
 
268
  self.finetuner = finetuner.eval().half()
269
 
 
273
 
274
  self.training = False
275
 
276
+ model_map['Custom'] = ft_path
277
 
278
+ return [gr.update(interactive=True), gr.update(interactive=True), ft_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
279
 
280
 
281
  def inference(self, prompt, seed, pbar = gr.Progress(track_tqdm=True)):