# Spaces: Running on T4
import gradio as gr | |
from PIL import Image | |
import torch | |
import re | |
import os | |
import requests | |
from customization import customize_vae_decoder | |
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, DDIMScheduler, EulerDiscreteScheduler | |
from torchvision import transforms | |
from attribution import MappingNetwork | |
import math | |
from typing import List | |
from PIL import Image | |
import cv2 | |
import numpy as np | |
import torch | |
# Module-level flag; nothing in the visible part of this file reads it.
is_gpu_busy: bool = False

# Directory that AttributionModel loads its trained weights from
# (vae_decoder.pth, mapping_network.pth, decoding_network.pth).
PRETRAINED_MODEL_NAME_OR_PATH: str = "./checkpoints/"
def get_image_grid(images: List[Image.Image], cols: int = 3) -> Image.Image:
    """Tile ``images`` left-to-right, top-to-bottom into one RGB image.

    Every cell has the size of the first image; other images are pasted
    unchanged at their cell's top-left corner. The grid always has ``cols``
    columns (3 by default, so up to three images fit on one row as before)
    and as many rows as are needed to hold every image.

    Args:
        images: Images to tile, in order.
        cols: Number of grid columns (must be >= 1).

    Returns:
        A new RGB image of size ``(cols * width, rows * height)``, where
        ``width``/``height`` come from ``images[0]``.

    Raises:
        ValueError: If ``images`` is empty or ``cols`` is not positive.
    """
    if not images:
        raise ValueError("get_image_grid() needs at least one image")
    if cols < 1:
        raise ValueError(f"cols must be >= 1, got {cols}")

    width, height = images[0].size
    # Grow the number of rows with the image count so no image is pasted
    # outside the canvas (and silently dropped).
    rows = math.ceil(len(images) / cols)
    grid_image = Image.new("RGB", (cols * width, rows * height))
    for index, img in enumerate(images):
        row, col = divmod(index, cols)
        grid_image.paste(img, (col * width, row * height))
    return grid_image
class AttributionModel:
    """Stable Diffusion 2 generator whose VAE decoder is conditioned on a key.

    ``pipe`` (stock ``stabilityai/stable-diffusion-2``) produces latents.
    ``vae`` is a separately loaded copy of the SD2 VAE whose decoder was
    modified by ``customize_vae_decoder``; it is decoded together with the
    output of ``mapping_network`` for a 32-element key. ``decoding_network``
    is a ResNet-50 with a 32-output head that ``detect_key`` applies to
    images. The three trained components are restored from
    ``PRETRAINED_MODEL_NAME_OR_PATH`` and placed on CUDA.
    """

    def __init__(self):
        # Stock pipeline, used to generate latents (and, through its own VAE,
        # by inference_without_attribution).
        # Options tried previously: safety_checker=None, torch_dtype=torch.float16.
        self.pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")
        self.pipe = self.pipe.to("cuda")
        self.resize_transform = transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR)

        # Second copy of the SD2 VAE, with a decoder customized to accept the
        # mapping network's output as extra conditioning.
        self.vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2", subfolder="vae")
        self.vae = customize_vae_decoder(self.vae, 128, "qkv", "all", False, 1.0)
        self.mapping_network = MappingNetwork(
            32, 0, 128, None, num_layers=2, w_avg_beta=None, normalization=False
        )

        from torchvision.models import resnet50, ResNet50_Weights
        self.decoding_network = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        # Replace the ImageNet classifier head with a 32-output head.
        self.decoding_network.fc = torch.nn.Linear(2048, 32)

        self.vae.decoder.load_state_dict(self._load_checkpoint("vae_decoder.pth"))
        self.mapping_network.load_state_dict(self._load_checkpoint("mapping_network.pth"))
        self.decoding_network.load_state_dict(self._load_checkpoint("decoding_network.pth"))

        # These networks are only used for inference. torchvision builds
        # resnet50 in training mode, so without eval() detect_key would use
        # per-batch BatchNorm statistics and overwrite the running stats on
        # every call. eval() is a no-op for modules without BN/dropout.
        self.vae = self.vae.to("cuda").eval()
        self.mapping_network = self.mapping_network.to("cuda").eval()
        self.decoding_network = self.decoding_network.to("cuda").eval()

        self.test_norm = transforms.Compose(
            [
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )

    @staticmethod
    def _load_checkpoint(filename):
        """Load a state dict from ``PRETRAINED_MODEL_NAME_OR_PATH`` onto the CPU.

        Mapping to CPU makes loading independent of the device the checkpoint
        was saved from; ``load_state_dict`` copies the tensors into each
        module's own parameters, and the modules are moved to CUDA afterwards.
        """
        path = os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, filename)
        return torch.load(path, map_location="cpu")

    def infer(self, prompt, negative, guidance_scale):
        """Generate one image for ``prompt`` and decode it with a random key.

        Args:
            prompt: Text prompt.
            negative: Negative prompt; an empty string (or None) means none.
            guidance_scale: Classifier-free guidance scale for the pipeline.

        Returns:
            A list of three references to the same numpy array produced by
            ``inference_with_attribution`` (values clamped to [0, 1]), one per
            output slot in the UI.
        """
        with torch.no_grad():
            out_latents = self.pipe(
                [prompt],
                # Forward the negative prompt; previously it was ignored.
                negative_prompt=[negative] if negative else None,
                output_type="latent",
                num_inference_steps=10,
                guidance_scale=guidance_scale,
            ).images
            image = self.inference_with_attribution(out_latents)
        # TODO: the UI has three slots (simple / attribute / diff) but all of
        # them currently receive the same attributed image.
        return [image[0]] * 3

    def inference_without_attribution(self, latents):
        """Decode ``latents`` with the pipeline's stock VAE (no key).

        Returns the decoded tensor clamped to [-1, 1], still on the device of
        ``self.pipe``. Unlike ``inference_with_attribution``, it is not
        rescaled or converted to numpy.
        """
        # Undo the latent scaling; 0.18215 is presumably the SD VAE
        # scaling factor (vae.config.scaling_factor) — confirm.
        latents = 1 / 0.18215 * latents
        with torch.no_grad():
            image = self.pipe.vae.decode(latents).sample
        image = image.clamp(-1, 1)
        return image

    def get_phis(self, phi_dimension, batch_size, eps=1e-8):
        """Sample random binary keys of shape ``(batch_size, phi_dimension)``.

        Each entry is 0 or 1 (probability 1/2 each), plus ``eps``, so the
        result is strictly positive.
        """
        phi_length = phi_dimension
        b = batch_size
        phi = torch.empty(b, phi_length).uniform_(0, 1)
        return torch.bernoulli(phi) + eps

    def inference_with_attribution(self, latents, key=None):
        """Decode ``latents`` with the key-conditioned VAE.

        Args:
            latents: Latents as produced by ``self.pipe`` with
                ``output_type="latent"``.
            key: Key tensor fed to ``mapping_network``; if None, a random
                32-element key for a batch of 1 is drawn via ``get_phis``.

        Returns:
            A float32 numpy array with values in [0, 1], channels moved last
            by ``permute(0, 2, 3, 1)`` (assumes the decoder returns NCHW).
        """
        if key is None:
            key = self.get_phis(32, 1)
        # See inference_without_attribution for the 0.18215 factor.
        latents = 1 / 0.18215 * latents
        with torch.no_grad():
            image = self.vae.decode(latents, self.mapping_network(key.cuda())).sample
        image = image.clamp(-1, 1)
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def postprocess(self, image):
        """Resize ``image`` so its shorter side is 512 px (bilinear)."""
        image = self.resize_transform(image)
        return image

    def detect_key(self, image):
        """Run ``decoding_network`` on ``image`` and return its 32 outputs.

        The image is mapped from [-1, 1] to [0, 1] and normalized with the
        ImageNet mean/std before the network. Assumes ``image`` is a batched
        CUDA tensor in [-1, 1] — TODO confirm against callers.
        """
        reconstructed_keys = self.decoding_network(self.test_norm((image / 2 + 0.5).clamp(0, 1)))
        return reconstructed_keys
# Built at import time: constructing AttributionModel loads every model and
# moves it to CUDA before the UI is created.
attribution_model = AttributionModel()

with gr.Blocks() as demo:
    # Prompt row: positive and negative prompts stacked, button beside them.
    with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
        with gr.Column():
            text = gr.Textbox(
                label="Enter your prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                elem_id="prompt-text-input",
            ).style(
                border=(True, False, True, True),
                rounded=(True, False, False, True),
                container=False,
            )
            negative = gr.Textbox(
                label="Enter your negative prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter a negative prompt",
                elem_id="negative-prompt-text-input",
            ).style(
                border=(True, False, True, True),
                rounded=(True, False, False, True),
                container=False,
            )
        btn = gr.Button("Generate image").style(full_width=False)

    # Three output slots; infer() currently fills all of them with the same image.
    with gr.Row():
        img_output_simple = gr.Image()
        img_output_attribute = gr.Image()
        img_output_diff = gr.Image()

    with gr.Row():
        guidance_scale = gr.Slider(
            label="Guidance Scale", minimum=0, maximum=10, value=9, step=0.1
        )

    # infer() returns raw float numpy arrays. gr.Image's postprocess step
    # (enabled by default) encodes them for the browser; postprocess=False
    # would send the arrays to the frontend unconverted.
    btn.click(
        attribution_model.infer,
        inputs=[text, negative, guidance_scale],
        outputs=[img_output_simple, img_output_attribute, img_output_diff],
    )
if __name__ == "__main__":
    # Queue requests: concurrency_count=1 processes one queued job at a
    # time, and at most 20 requests may wait in the queue.
    queued_demo = demo.queue(concurrency_count=1, max_size=20)
    queued_demo.launch(max_threads=50)