bokyeong1015 committed
Commit 3a45ac7
1 Parent(s): a9b3bf8

add nparams count

Files changed (2)
  1. app.py +10 -5
  2. demo.py +11 -0
app.py CHANGED
@@ -50,20 +50,25 @@ if __name__ == "__main__":
             with gr.Tab("Example Prompts"):
                 examples = gr.Examples(examples=example_list, inputs=[text])
 
-        with gr.Column(variant='panel', scale=35):
+        with gr.Column(variant='panel',scale=35):
             # Define original model output components
             gr.Markdown('<h2 align="center">Original Stable Diffusion 1.4</h2>')
             original_model_output = gr.Image(label="Original Model")
             with gr.Row().style(equal_height=True):
-                original_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
+                with gr.Column():
+                    original_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
+                    original_model_params = gr.Textbox(value=servicer.get_sdm_params(servicer.pipe_original), label="# Parameters")
                 original_model_error = gr.Markdown()
+
 
-        with gr.Column(variant='panel', scale=35):
+        with gr.Column(variant='panel',scale=35):
             # Define compressed model output components
             gr.Markdown('<h2 align="center">Compressed Stable Diffusion (Ours)</h2>')
-            compressed_model_output = gr.Image(label="Compressed Model")
+            compressed_model_output = gr.Image(label="Compressed Model")
             with gr.Row().style(equal_height=True):
-                compressed_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
+                with gr.Column():
+                    compressed_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
+                    compressed_model_params = gr.Textbox(value=servicer.get_sdm_params(servicer.pipe_compressed), label="# Parameters")
                 compressed_model_error = gr.Markdown()
 
         inputs = [text, negative, guidance_scale, steps, seed]
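Note that the new "# Parameters" textboxes are filled once, when the Blocks layout is built, from servicer.get_sdm_params(...); only the inference-time textboxes are updated later by the generate callbacks. A minimal sketch of that pattern, with an illustrative placeholder standing in for the servicer (component and helper names here are not the Space's actual code):

import gradio as gr

def get_params_label() -> str:
    # Placeholder for servicer.get_sdm_params(pipe); a real value would be computed from the model.
    return "Total 0.0M (U-Net 0.0M)"

with gr.Blocks() as demo:
    with gr.Column(variant='panel'):
        gr.Markdown('<h2 align="center">Original Stable Diffusion 1.4</h2>')
        output_image = gr.Image(label="Original Model")
        with gr.Row():
            with gr.Column():
                # Updated later by the inference callback.
                test_time = gr.Textbox(value="", label="Inference Time (sec)")
                # Set once, at layout-construction time.
                n_params = gr.Textbox(value=get_params_label(), label="# Parameters")

demo.launch()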
demo.py CHANGED
@@ -26,6 +26,17 @@ class SdmCompressionDemo:
         self.pipe_compressed = self.pipe_compressed.to(self.device)
         self.device_msg = 'Tested on GPU.' if 'cuda' in self.device else 'Tested on CPU.'
 
+    def _count_params(self, model):
+        return sum(p.numel() for p in model.parameters())
+
+    def get_sdm_params(self, pipe):
+        params_unet = self._count_params(pipe.unet)
+        params_text_enc = self._count_params(pipe.text_encoder)
+        params_image_dec = self._count_params(pipe.vae.decoder)
+        params_total = params_unet + params_text_enc + params_image_dec
+        return f"Total {(params_total/1e6):.1f}M (U-Net {(params_unet/1e6):.1f}M)"
+
+
     def generate_image(self, pipe, text, negative, guidance_scale, steps, seed):
         generator = torch.Generator(self.device).manual_seed(seed)
         start = time.time()
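For reference, get_sdm_params counts only the U-Net, text encoder, and VAE decoder; the VAE encoder (unused for text-to-image) and any safety-checker weights are not included in the total. A standalone sketch of the same counting logic against a diffusers StableDiffusionPipeline (the checkpoint name below is illustrative, not necessarily the one loaded by this Space):

import torch
from diffusers import StableDiffusionPipeline

def count_params(module: torch.nn.Module) -> int:
    # Number of elements summed across all parameter tensors of the module.
    return sum(p.numel() for p in module.parameters())

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

params_unet = count_params(pipe.unet)
params_text_enc = count_params(pipe.text_encoder)
params_image_dec = count_params(pipe.vae.decoder)
params_total = params_unet + params_text_enc + params_image_dec

# Prints a summary in the same format as get_sdm_params.
print(f"Total {params_total/1e6:.1f}M (U-Net {params_unet/1e6:.1f}M)")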