File size: 5,126 Bytes
5c762ce
 
 
 
 
 
 
 
 
 
 
 
 
33a2e0a
 
5c762ce
 
 
 
 
 
 
 
 
 
 
 
 
3a45ac7
 
 
 
 
 
 
 
 
 
 
5c762ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
import torch
import copy

import time

# Model id of the reference (uncompressed) Stable Diffusion v1.4 pipeline,
# passed to StableDiffusionPipeline.from_pretrained.
ORIGINAL_CHECKPOINT_ID = "CompVis/stable-diffusion-v1-4"
# Local checkpoint directory of the compressed model; its U-Net is read from
# the "unet" subfolder by SdmCompressionDemo.
COMPRESSED_UNET_PATH = "checkpoints/BK-SDM-Small_iter50000"

# Device the demo runs on. Change to 'cpu' to run without a GPU.
DEVICE = 'cuda'

class SdmCompressionDemo:
    """Side-by-side text-to-image demo: original SD v1.4 vs. a compressed U-Net.

    ``pipe_original`` is loaded from ``ORIGINAL_CHECKPOINT_ID``.
    ``pipe_compressed`` is a deep copy of it whose U-Net is replaced by the one
    stored under ``COMPRESSED_UNET_PATH``, so the two pipelines differ only in
    their U-Net.
    """

    # Example prompts offered to the UI. Kept immutable here and copied on
    # access so callers cannot mutate the shared list.
    _EXAMPLE_PROMPTS = (
        'a tropical bird sitting on a branch of a tree',
        'many decorative umbrellas hanging up',
        'an orange cat staring off with pretty eyes',
        'beautiful woman face with fancy makeup',
        'a decorated living room with a stylish feel',
        'a black vase holding a bouquet of roses',
        'very elegant bedroom featuring natural wood',
        'buffet-style food including cake and cheese',
        'a tall castle sitting under a cloudy sky',
        'closeup of a brown bear sitting in a grassy area',
        'a large basket with many fresh vegetables',
        'house being built with lots of wood',
        'a close up of a pizza with several toppings',
        'a golden vase with many different flows',
        'a statue of a lion face attached to brick wall',
        'something that looks particularly interesting',
        'table filled with a variety of different dishes',
        'a cinematic view of a large snowy peak',
        'a grand city in the year 2100, hyper realistic',
        'a blue eyed baby girl looking at the camera',
    )

    # Status message returned (with no image) when the prompt is empty.
    _EMPTY_PROMPT_MSG = "Please enter the input prompt."

    def __init__(self, device) -> None:
        """Load both pipelines and, for a CUDA device, move them onto it.

        Args:
            device: torch device string, e.g. ``'cuda'``, ``'cuda:0'`` or ``'cpu'``.
        """
        self.device = device
        use_cuda = 'cuda' in self.device
        # Half precision on GPU; fp32 on CPU, where fp16 is generally
        # unsupported or slow.
        self.torch_dtype = torch.float16 if use_cuda else torch.float32

        self.pipe_original = StableDiffusionPipeline.from_pretrained(
            ORIGINAL_CHECKPOINT_ID, torch_dtype=self.torch_dtype)
        # Independent copy of every component; only the U-Net is swapped out.
        self.pipe_compressed = copy.deepcopy(self.pipe_original)
        self.pipe_compressed.unet = UNet2DConditionModel.from_pretrained(
            COMPRESSED_UNET_PATH, subfolder="unet", torch_dtype=self.torch_dtype)
        if use_cuda:
            self.pipe_original = self.pipe_original.to(self.device)
            self.pipe_compressed = self.pipe_compressed.to(self.device)
        self.device_msg = 'Tested on GPU.' if use_cuda else 'Tested on CPU.'

    def _count_params(self, model):
        """Return the total number of elements across ``model.parameters()``."""
        return sum(p.numel() for p in model.parameters())

    def get_sdm_params(self, pipe) -> str:
        """Return a human-readable parameter count for ``pipe``.

        The total covers the U-Net, the text encoder and the VAE *decoder*.
        The VAE encoder is not counted, presumably because text-to-image
        generation only decodes latents.

        Returns:
            A string of the form ``"Total <N>M (U-Net <U>M)"`` with one decimal.
        """
        params_unet = self._count_params(pipe.unet)
        params_text_enc = self._count_params(pipe.text_encoder)
        params_image_dec = self._count_params(pipe.vae.decoder)
        params_total = params_unet + params_text_enc + params_image_dec
        return f"Total {(params_total/1e6):.1f}M (U-Net {(params_unet/1e6):.1f}M)"

    def generate_image(self, pipe, text, negative, guidance_scale, steps, seed):
        """Run ``pipe`` once with a seeded generator and time the call.

        Args:
            pipe: the diffusers pipeline to call.
            text: positive prompt.
            negative: negative prompt.
            guidance_scale: classifier-free guidance scale.
            steps: number of denoising steps.
            seed: seed for the ``torch.Generator`` on ``self.device``.

        Returns:
            ``(image, nsfw_detected, elapsed)``. ``image`` is the first
            generated image. ``nsfw_detected`` is the safety flag for that
            image, or ``False`` when the pipeline reports no flags.
            ``elapsed`` is the wall-clock duration formatted with two decimals.
        """
        generator = torch.Generator(self.device).manual_seed(seed)
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        result = pipe(text, negative_prompt=negative, generator=generator,
                      guidance_scale=guidance_scale, num_inference_steps=steps)
        test_time = time.perf_counter() - start

        image = result.images[0]
        # May be None (e.g. pipeline without a safety checker); don't index it blindly.
        nsfw_flags = result.nsfw_content_detected
        nsfw_detected = bool(nsfw_flags[0]) if nsfw_flags else False
        print(f"text {text} | Processed time: {test_time} sec | nsfw_flag {nsfw_detected}")
        print(f"negative {negative} | guidance_scale {guidance_scale} | steps {steps} ")
        print("===========")

        return image, nsfw_detected, format(test_time, ".2f")

    def error_msg(self, nsfw_detected) -> str:
        """Return the status message shown next to a generated image.

        An explanation is appended when the safety checker flagged the output.
        """
        if nsfw_detected:
            return self.device_msg+" Black images are returned when potential harmful content is detected. Try different prompts or seeds."
        return self.device_msg

    def check_invalid_input(self, text) -> bool:
        """Return True when ``text`` is missing, empty or only whitespace."""
        return text is None or not text.strip()

    def _infer(self, label, pipe, text, negative, guidance_scale, steps, seed):
        """Shared body of the ``infer_*_model`` entry points.

        ``label`` is used only in the log line.
        """
        print(f"=== {label} model --- seed {seed}")
        if self.check_invalid_input(text):
            return None, self._EMPTY_PROMPT_MSG, None
        output_image, nsfw_detected, test_time = self.generate_image(
            pipe, text, negative, guidance_scale, steps, seed)
        return output_image, self.error_msg(nsfw_detected), test_time

    def infer_original_model(self, text, negative, guidance_scale, steps, seed):
        """Generate with the original pipeline.

        Returns:
            ``(image, status_message, elapsed)``. ``image`` and ``elapsed`` are
            None when the prompt is empty.
        """
        return self._infer("ORIG", self.pipe_original,
                           text, negative, guidance_scale, steps, seed)

    def infer_compressed_model(self, text, negative, guidance_scale, steps, seed):
        """Generate with the compressed-U-Net pipeline.

        Returns:
            ``(image, status_message, elapsed)``. ``image`` and ``elapsed`` are
            None when the prompt is empty.
        """
        return self._infer("COMPRESSED", self.pipe_compressed,
                           text, negative, guidance_scale, steps, seed)

    def get_example_list(self):
        """Return a fresh list of example prompts."""
        return list(self._EXAMPLE_PROMPTS)