Spaces:
Runtime error
Runtime error
bokyeong1015
commited on
Commit
·
3a45ac7
1
Parent(s):
a9b3bf8
add nparams count
Browse files
app.py
CHANGED
@@ -50,20 +50,25 @@ if __name__ == "__main__":
|
|
50 |
with gr.Tab("Example Prompts"):
|
51 |
examples = gr.Examples(examples=example_list, inputs=[text])
|
52 |
|
53 |
-
with gr.Column(variant='panel',
|
54 |
# Define original model output components
|
55 |
gr.Markdown('<h2 align="center">Original Stable Diffusion 1.4</h2>')
|
56 |
original_model_output = gr.Image(label="Original Model")
|
57 |
with gr.Row().style(equal_height=True):
|
58 |
-
|
|
|
|
|
59 |
original_model_error = gr.Markdown()
|
|
|
60 |
|
61 |
-
with gr.Column(variant='panel',
|
62 |
# Define compressed model output components
|
63 |
gr.Markdown('<h2 align="center">Compressed Stable Diffusion (Ours)</h2>')
|
64 |
-
compressed_model_output = gr.Image(label="Compressed Model")
|
65 |
with gr.Row().style(equal_height=True):
|
66 |
-
|
|
|
|
|
67 |
compressed_model_error = gr.Markdown()
|
68 |
|
69 |
inputs = [text, negative, guidance_scale, steps, seed]
|
|
|
50 |
with gr.Tab("Example Prompts"):
|
51 |
examples = gr.Examples(examples=example_list, inputs=[text])
|
52 |
|
53 |
+
with gr.Column(variant='panel',scale=35):
|
54 |
# Define original model output components
|
55 |
gr.Markdown('<h2 align="center">Original Stable Diffusion 1.4</h2>')
|
56 |
original_model_output = gr.Image(label="Original Model")
|
57 |
with gr.Row().style(equal_height=True):
|
58 |
+
with gr.Column():
|
59 |
+
original_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
|
60 |
+
original_model_params = gr.Textbox(value=servicer.get_sdm_params(servicer.pipe_original), label="# Parameters")
|
61 |
original_model_error = gr.Markdown()
|
62 |
+
|
63 |
|
64 |
+
with gr.Column(variant='panel',scale=35):
|
65 |
# Define compressed model output components
|
66 |
gr.Markdown('<h2 align="center">Compressed Stable Diffusion (Ours)</h2>')
|
67 |
+
compressed_model_output = gr.Image(label="Compressed Model")
|
68 |
with gr.Row().style(equal_height=True):
|
69 |
+
with gr.Column():
|
70 |
+
compressed_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
|
71 |
+
compressed_model_params = gr.Textbox(value=servicer.get_sdm_params(servicer.pipe_compressed), label="# Parameters")
|
72 |
compressed_model_error = gr.Markdown()
|
73 |
|
74 |
inputs = [text, negative, guidance_scale, steps, seed]
|
demo.py
CHANGED
@@ -26,6 +26,17 @@ class SdmCompressionDemo:
|
|
26 |
self.pipe_compressed = self.pipe_compressed.to(self.device)
|
27 |
self.device_msg = 'Tested on GPU.' if 'cuda' in self.device else 'Tested on CPU.'
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
def generate_image(self, pipe, text, negative, guidance_scale, steps, seed):
|
30 |
generator = torch.Generator(self.device).manual_seed(seed)
|
31 |
start = time.time()
|
|
|
26 |
self.pipe_compressed = self.pipe_compressed.to(self.device)
|
27 |
self.device_msg = 'Tested on GPU.' if 'cuda' in self.device else 'Tested on CPU.'
|
28 |
|
29 |
+
def _count_params(self, model):
|
30 |
+
return sum(p.numel() for p in model.parameters())
|
31 |
+
|
32 |
+
def get_sdm_params(self, pipe):
|
33 |
+
params_unet = self._count_params(pipe.unet)
|
34 |
+
params_text_enc = self._count_params(pipe.text_encoder)
|
35 |
+
params_image_dec = self._count_params(pipe.vae.decoder)
|
36 |
+
params_total = params_unet + params_text_enc + params_image_dec
|
37 |
+
return f"Total {(params_total/1e6):.1f}M (U-Net {(params_unet/1e6):.1f}M)"
|
38 |
+
|
39 |
+
|
40 |
def generate_image(self, pipe, text, negative, guidance_scale, steps, seed):
|
41 |
generator = torch.Generator(self.device).manual_seed(seed)
|
42 |
start = time.time()
|