# Hugging Face Spaces status: Runtime error
from diffusers import StableDiffusionPipeline, UNet2DConditionModel | |
import torch | |
import copy | |
import time | |
# Hub id of the baseline Stable Diffusion checkpoint.
ORIGINAL_CHECKPOINT_ID = "CompVis/stable-diffusion-v1-4"
# Local directory holding the compressed (BK-SDM-Small) weights; the U-Net is
# loaded from its "unet" subfolder.
COMPRESSED_UNET_PATH = "checkpoints/BK-SDM-Small_iter50000"
# Use the GPU when one is visible and fall back to CPU otherwise. A hard-coded
# 'cuda' makes the demo crash at start-up on CPU-only hardware, when the
# pipelines are moved to a device that does not exist.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
class SdmCompressionDemo:
    """Side-by-side text-to-image demo: original SD v1-4 vs. a compressed U-Net.

    The two pipelines differ only in their U-Net. The heavy, read-only modules
    (text encoder, VAE, safety checker, ...) are shared between them.
    """

    def __init__(self) -> None:
        self.device = DEVICE
        # Half precision only on GPU; CPU inference stays in float32.
        self.torch_dtype = torch.float16 if 'cuda' in self.device else torch.float32

        self.pipe_original = StableDiffusionPipeline.from_pretrained(
            ORIGINAL_CHECKPOINT_ID, torch_dtype=self.torch_dtype)

        compressed_unet = UNet2DConditionModel.from_pretrained(
            COMPRESSED_UNET_PATH, subfolder="unet", torch_dtype=self.torch_dtype)

        # Build the compressed pipeline from the original's components rather
        # than deep-copying the whole pipeline. A deepcopy duplicates the text
        # encoder, VAE and safety checker, plus a U-Net that is immediately
        # thrown away. The scheduler and tokenizer are cheap and carry
        # per-call mutable state, so the compressed pipeline gets its own
        # copies of those two.
        components = dict(self.pipe_original.components)
        components["unet"] = compressed_unet
        for name in ("scheduler", "tokenizer"):
            components[name] = copy.deepcopy(components[name])
        self.pipe_compressed = StableDiffusionPipeline(**components)

        if 'cuda' in self.device:
            self.pipe_original = self.pipe_original.to(self.device)
            self.pipe_compressed = self.pipe_compressed.to(self.device)

        self.device_msg = 'Tested on GPU.' if 'cuda' in self.device else 'Tested on CPU.'

    def generate_image(self, pipe, text, negative, guidance_scale, steps, seed):
        """Run one generation with ``pipe`` and time it.

        Returns:
            (image, nsfw_detected, elapsed). ``image`` is the first generated
            image. ``nsfw_detected`` is a bool. ``elapsed`` is the wall-clock
            generation time in seconds, formatted with two decimals.
        """
        # UI number widgets may deliver floats; manual_seed and the step count
        # both require ints.
        seed = int(seed)
        steps = int(steps)
        generator = torch.Generator(self.device).manual_seed(seed)
        start = time.perf_counter()
        result = pipe(text, negative_prompt=negative, generator=generator,
                      guidance_scale=guidance_scale, num_inference_steps=steps)
        test_time = time.perf_counter() - start
        image = result.images[0]
        # nsfw_content_detected is None when the pipeline runs without a
        # safety checker; treat that as "nothing flagged".
        nsfw_flags = result.nsfw_content_detected
        nsfw_detected = bool(nsfw_flags[0]) if nsfw_flags else False
        print(f"text {text} | Processed time: {test_time} sec | nsfw_flag {nsfw_detected}")
        print(f"negative {negative} | guidance_scale {guidance_scale} | steps {steps} ")
        print("===========")
        return image, nsfw_detected, format(test_time, ".2f")

    def error_msg(self, nsfw_detected):
        """Return the status message shown next to a generated image."""
        if nsfw_detected:
            return self.device_msg + " Black images are returned when potential harmful content is detected. Try different prompts or seeds."
        return self.device_msg

    def check_invalid_input(self, text):
        """Return True when the prompt is missing, empty, or only whitespace."""
        return text is None or not text.strip()

    def _infer(self, label, pipe, text, negative, guidance_scale, steps, seed):
        """Validate the prompt, generate with ``pipe``, and build the UI outputs.

        Returns (image, message, elapsed). When the prompt is invalid it
        returns (None, prompt-request message, None).
        """
        print(f"=== {label} model --- seed {seed}")
        if self.check_invalid_input(text):
            return None, "Please enter the input prompt.", None
        output_image, nsfw_detected, test_time = self.generate_image(
            pipe, text, negative, guidance_scale, steps, seed)
        return output_image, self.error_msg(nsfw_detected), test_time

    def infer_original_model(self, text, negative, guidance_scale, steps, seed):
        """Generate with the original SD v1-4 pipeline."""
        return self._infer("ORIG", self.pipe_original,
                           text, negative, guidance_scale, steps, seed)

    def infer_compressed_model(self, text, negative, guidance_scale, steps, seed):
        """Generate with the compressed-U-Net pipeline."""
        return self._infer("COMPRESSED", self.pipe_compressed,
                           text, negative, guidance_scale, steps, seed)

    def get_example_list(self):
        """Return a fresh list of sample prompts for the UI."""
        return [
            'a tropical bird sitting on a branch of a tree',
            'many decorative umbrellas hanging up',
            'an orange cat staring off with pretty eyes',
            'beautiful woman face with fancy makeup',
            'a decorated living room with a stylish feel',
            'a black vase holding a bouquet of roses',
            'very elegant bedroom featuring natural wood',
            'buffet-style food including cake and cheese',
            'a tall castle sitting under a cloudy sky',
            'closeup of a brown bear sitting in a grassy area',
            'a large basket with many fresh vegetables',
            'house being built with lots of wood',
            'a close up of a pizza with several toppings',
            'a golden vase with many different flows',
            'a statue of a lion face attached to brick wall',
            'something that looks particularly interesting',
            'table filled with a variety of different dishes',
            'a cinematic view of a large snowy peak',
            'a grand city in the year 2100, hyper realistic',
            'a blue eyed baby girl looking at the camera',
        ]