mygyasir commited on
Commit
a0b0c18
1 Parent(s): 803e17b

Update demo.py

Browse files
Files changed (1) hide show
  1. demo.py +0 -111
demo.py CHANGED
@@ -1,111 +0,0 @@
1
- from diffusers import StableDiffusionPipeline, UNet2DConditionModel
2
- import torch
3
- import copy
4
-
5
- import time
6
-
7
- ORIGINAL_CHECKPOINT_ID = "CompVis/stable-diffusion-v1-4"
8
- COMPRESSED_UNET_ID = "nota-ai/bk-sdm-small"
9
-
10
- DEVICE='cuda'
11
- # DEVICE='cpu'
12
-
13
- class SdmCompressionDemo:
14
- def __init__(self, device) -> None:
15
- self.device = device
16
- self.torch_dtype = torch.float16 if 'cuda' in self.device else torch.float32
17
-
18
- self.pipe_original = StableDiffusionPipeline.from_pretrained(ORIGINAL_CHECKPOINT_ID,
19
- torch_dtype=self.torch_dtype)
20
- self.pipe_compressed = copy.deepcopy(self.pipe_original)
21
- self.pipe_compressed.unet = UNet2DConditionModel.from_pretrained(COMPRESSED_UNET_ID,
22
- subfolder="unet",
23
- torch_dtype=self.torch_dtype)
24
- if 'cuda' in self.device:
25
- self.pipe_original = self.pipe_original.to(self.device)
26
- self.pipe_compressed = self.pipe_compressed.to(self.device)
27
- self.device_msg = 'Tested on GPU.' if 'cuda' in self.device else 'Tested on CPU.'
28
-
29
- def _count_params(self, model):
30
- return sum(p.numel() for p in model.parameters())
31
-
32
- def get_sdm_params(self, pipe):
33
- params_unet = self._count_params(pipe.unet)
34
- params_text_enc = self._count_params(pipe.text_encoder)
35
- params_image_dec = self._count_params(pipe.vae.decoder)
36
- params_total = params_unet + params_text_enc + params_image_dec
37
- return f"Total {(params_total/1e6):.1f}M (U-Net {(params_unet/1e6):.1f}M)"
38
-
39
-
40
- def generate_image(self, pipe, text, negative, guidance_scale, steps, seed):
41
- generator = torch.Generator(self.device).manual_seed(seed)
42
- start = time.time()
43
- result = pipe(text, negative_prompt = negative, generator = generator,
44
- guidance_scale = guidance_scale, num_inference_steps = steps)
45
- test_time = time.time() - start
46
-
47
- image = result.images[0]
48
- nsfw_detected = result.nsfw_content_detected[0]
49
- print(f"text {text} | Processed time: {test_time} sec | nsfw_flag {nsfw_detected}")
50
- print(f"negative {negative} | guidance_scale {guidance_scale} | steps {steps} ")
51
- print("===========")
52
-
53
- return image, nsfw_detected, format(test_time, ".2f")
54
-
55
- def error_msg(self, nsfw_detected):
56
- if nsfw_detected:
57
- return self.device_msg+" Black images are returned when potential harmful content is detected. Try different prompts or seeds."
58
- else:
59
- return self.device_msg
60
-
61
- def check_invalid_input(self, text):
62
- if text == '':
63
- return True
64
-
65
- def infer_original_model(self, text, negative, guidance_scale, steps, seed):
66
- print(f"=== ORIG model --- seed {seed}")
67
- if self.check_invalid_input(text):
68
- return None, "Please enter the input prompt.", None
69
- output_image, nsfw_detected, test_time = self.generate_image(self.pipe_original,
70
- text, negative, guidance_scale, steps, seed)
71
-
72
- return output_image, self.error_msg(nsfw_detected), test_time
73
-
74
- def infer_compressed_model(self, text, negative, guidance_scale, steps, seed):
75
- print(f"=== COMPRESSED model --- seed {seed}")
76
- if self.check_invalid_input(text):
77
- return None, "Please enter the input prompt.", None
78
- output_image, nsfw_detected, test_time = self.generate_image(self.pipe_compressed,
79
- text, negative, guidance_scale, steps, seed)
80
-
81
- return output_image, self.error_msg(nsfw_detected), test_time
82
-
83
-
84
- def get_example_list(self):
85
- return [
86
- 'a tropical bird sitting on a branch of a tree',
87
- 'many decorative umbrellas hanging up',
88
- 'an orange cat staring off with pretty eyes',
89
- 'beautiful woman face with fancy makeup',
90
- 'a decorated living room with a stylish feel',
91
- 'a black vase holding a bouquet of roses',
92
- 'very elegant bedroom featuring natural wood',
93
- 'buffet-style food including cake and cheese',
94
- 'a tall castle sitting under a cloudy sky',
95
- 'closeup of a brown bear sitting in a grassy area',
96
- 'a large basket with many fresh vegetables',
97
- 'house being built with lots of wood',
98
- 'a close up of a pizza with several toppings',
99
- 'a golden vase with many different flows',
100
- 'a statue of a lion face attached to brick wall',
101
- 'something that looks particularly interesting',
102
- 'table filled with a variety of different dishes',
103
- 'a cinematic view of a large snowy peak',
104
- 'a grand city in the year 2100, hyper realistic',
105
- 'a blue eyed baby girl looking at the camera',
106
- ]
107
-
108
-
109
-
110
-
111
-