ysharma HF staff commited on
Commit
dbedcf7
1 Parent(s): c03ec2b

updating examples

Browse files
Files changed (1) hide show
  1. app.py +17 -5
app.py CHANGED
@@ -21,14 +21,19 @@ torch.manual_seed(1)
21
  counter = 0
22
 
23
  #Getting Lora fine-tuned weights
24
- def monkeypatching(alpha, in_prompt): #, prompt, pipe): finetuned_lora_weights
25
  print("****** inside monkeypatching *******")
26
  print(f"in_prompt is - {str(in_prompt)}")
27
  global counter
28
  if counter == 0 :
29
- monkeypatch_lora(pipe.unet, torch.load("./output_example/lora_weight.pt")) #finetuned_lora_weights
30
- tune_lora_scale(pipe.unet, alpha) #1.00)
31
- counter +=1
 
 
 
 
 
32
  else :
33
  tune_lora_scale(pipe.unet, alpha) #1.00)
34
  prompt = "style of hclu, " + str(in_prompt) #"baby lion"
@@ -77,9 +82,16 @@ with gr.Blocks() as demo:
77
  in_steps = gr.Number(label="Enter the number of training steps", value = 4000)
78
  in_alpha = gr.Slider(0.1,1.0, step=0.01, label="Set Alpha level", value=0.5)
79
  out_file = gr.File(label="Lora trained model weights", )
 
 
 
 
 
 
 
80
 
81
  b1.click(fn = accelerate_train_lora, inputs=in_steps, outputs=out_file)
82
- b2.click(fn = monkeypatching, inputs=[in_alpha, in_prompt], outputs=out_image)
83
 
84
  demo.queue(concurrency_count=3)
85
  demo.launch(debug=True, show_error=True)
 
21
  counter = 0
22
 
23
  #Getting Lora fine-tuned weights
24
+ def monkeypatching(alpha, in_prompt, example_wt): #, prompt, pipe): finetuned_lora_weights
25
  print("****** inside monkeypatching *******")
26
  print(f"in_prompt is - {str(in_prompt)}")
27
  global counter
28
  if counter == 0 :
29
+ if example_wt is None :
30
+ monkeypatch_lora(pipe.unet, torch.load("./output_example/lora_weight.pt")) #finetuned_lora_weights
31
+ tune_lora_scale(pipe.unet, alpha) #1.00)
32
+ counter +=1
33
+ else:
34
+ monkeypatch_lora(pipe.unet, torch.load(example_wt)) #finetuned_lora_weights
35
+ tune_lora_scale(pipe.unet, alpha) #1.00)
36
+ counter +=1
37
  else :
38
  tune_lora_scale(pipe.unet, alpha) #1.00)
39
  prompt = "style of hclu, " + str(in_prompt) #"baby lion"
 
82
  in_steps = gr.Number(label="Enter the number of training steps", value = 4000)
83
  in_alpha = gr.Slider(0.1,1.0, step=0.01, label="Set Alpha level", value=0.5)
84
  out_file = gr.File(label="Lora trained model weights", )
85
+
86
+ gr.Examples(
87
+ examples=[[0.65, "lion", "./lora_playgroundai_wt.pt" ]],
88
+ inputs=[in_alpha, in_prompt, example_wt],
89
+ outputs=out_image,
90
+ fn=monkeypatching,
91
+ cache_examples=True,)
92
 
93
  b1.click(fn = accelerate_train_lora, inputs=in_steps, outputs=out_file)
94
+ b2.click(fn = monkeypatching, inputs=[in_alpha, in_prompt, example_wt], outputs=out_image)
95
 
96
  demo.queue(concurrency_count=3)
97
  demo.launch(debug=True, show_error=True)