Sambhavnoobcoder committed
Commit
8c97ab4
1 Parent(s): ad80697

Upload 3 files

Files changed (3)
  1. Utils.py +88 -0
  2. app (1).py +83 -0
  3. requirements.txt +6 -0
Utils.py ADDED
@@ -0,0 +1,88 @@
import os

from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from tqdm.auto import tqdm
from PIL import Image
import torch


class MingleModel:

    def __init__(self):
        # Set device
        self.torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        # Read the Hugging Face access token from the environment rather than hard-coding it in the repo.
        use_auth_token = os.environ.get("HF_TOKEN")

        # Load the autoencoder model which will be used to decode the latents into image space.
        self.vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae",
                                                 use_auth_token=use_auth_token).to(self.torch_device)

        # Load the tokenizer and text encoder to tokenize and encode the text.
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", use_auth_token=use_auth_token)
        self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14",
                                                          use_auth_token=use_auth_token).to(self.torch_device)

        # The UNet model for generating the latents.
        self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet",
                                                         use_auth_token=use_auth_token).to(self.torch_device)

        # The noise scheduler
        self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                              num_train_timesteps=1000)

    def do_tokenizer(self, prompt):
        return self.tokenizer([prompt], padding="max_length", max_length=self.tokenizer.model_max_length,
                              truncation=True, return_tensors="pt")

    def get_text_encoder(self, text_input):
        return self.text_encoder(text_input.input_ids.to(self.torch_device))[0]

    def latents_to_pil(self, latents):
        # batch of latents -> list of PIL images
        latents = (1 / 0.18215) * latents
        with torch.no_grad():
            image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image) for image in images]
        return pil_images

    def generate_with_embs(self, text_embeddings, generator_int=32, num_inference_steps=30, guidance_scale=7.5):
        height = 512  # default height of Stable Diffusion
        width = 512   # default width of Stable Diffusion
        generator = torch.manual_seed(generator_int)  # Seed generator to create the initial latent noise
        batch_size = 1

        # Unconditional (empty-prompt) embeddings for classifier-free guidance
        max_length = 77
        uncond_input = self.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        with torch.no_grad():
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.torch_device))[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # Prep scheduler
        self.scheduler.set_timesteps(num_inference_steps)

        # Prep latents
        latents = torch.randn((batch_size, self.unet.config.in_channels, height // 8, width // 8),
                              generator=generator)
        latents = latents.to(self.torch_device)
        latents = latents * self.scheduler.init_noise_sigma

        # Denoising loop
        for i, t in tqdm(enumerate(self.scheduler.timesteps), total=num_inference_steps):
            # Expand the latents since we are doing classifier-free guidance, to avoid two forward passes.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Predict the noise residual
            with torch.no_grad():
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

            # Perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        return self.latents_to_pil(latents)[0]
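
For reference, here is a minimal sketch of how MingleModel can be driven outside the Gradio app below; the prompt, seed, and output filename are only illustrative, and the sketch assumes the Stable Diffusion weights can be downloaded (for example with a Hugging Face token available in the environment).

import torch

from Utils import MingleModel

model = MingleModel()

# Encode one prompt into CLIP text embeddings (the prompt is just an example).
text_input = model.do_tokenizer("a red evening dress on a mannequin, studio lighting")
with torch.no_grad():
    text_embeddings = model.get_text_encoder(text_input)

# Generate a single 512x512 image from those embeddings and save it.
image = model.generate_with_embs(text_embeddings, generator_int=42,
                                 num_inference_steps=30, guidance_scale=7.5)
image.save("sample.png")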
app (1).py ADDED
@@ -0,0 +1,83 @@
import gradio as gr
import torch
from transformers import logging
import random
from PIL import Image
from Utils import MingleModel

logging.set_verbosity_error()


def get_concat_h(images):
    # Concatenate a list of PIL images horizontally.
    widths, heights = zip(*(i.size for i in images))

    total_width = sum(widths)
    max_height = max(heights)

    dst = Image.new('RGB', (total_width, max_height))
    x_offset = 0
    for im in images:
        dst.paste(im, (x_offset, 0))
        x_offset += im.size[0]
    return dst


mingle_model = MingleModel()


def mingle_prompts(first_prompt, second_prompt):
    imgs = []
    text_input1 = mingle_model.do_tokenizer(first_prompt)
    text_input2 = mingle_model.do_tokenizer(second_prompt)
    with torch.no_grad():
        text_embeddings1 = mingle_model.get_text_encoder(text_input1)
        text_embeddings2 = mingle_model.get_text_encoder(text_input2)

    seed = random.randint(1, 2048)
    # Mix the two prompt embeddings together
    # mix_factors = [0.1, 0.3, 0.5, 0.7, 0.9]
    mix_factors = [0.5]
    for mix_factor in mix_factors:
        mixed_embeddings = (text_embeddings1 * mix_factor + text_embeddings2 * (1 - mix_factor))

        # Generate!
        steps = 20
        guidance_scale = 8.0
        img = mingle_model.generate_with_embs(mixed_embeddings, seed, num_inference_steps=steps,
                                              guidance_scale=guidance_scale)
        imgs.append(img)

    return get_concat_h(imgs)


with gr.Blocks() as demo:
    gr.Markdown(
        '''
        <h1 style="text-align: center;"> Fashion Generator GAN</h1>
        ''')

    gr.Markdown(
        '''
        <h3 style="text-align: center;"> Note: the model is extremely resource intensive, so running inference on CPU takes a long time. Kindly wait patiently while the model generates the output. </h3>
        ''')

    gr.Markdown(
        '''
        <p style="text-align: center;">Generates an image from the average of the two prompts entered.</p>
        ''')

    first_prompt = gr.Textbox(label="first_prompt")
    second_prompt = gr.Textbox(label="second_prompt")
    greet_btn = gr.Button("Submit")
    # gr.Markdown("## Text Examples")
    # gr.Examples([['batman, dynamic lighting, photorealistic fantasy concept art, trending on art station, stunning visuals, terrifying, creative, cinematic',
    #               'venom, dynamic lighting, photorealistic fantasy concept art, trending on art station, stunning visuals, terrifying, creative, cinematic'],
    #              ['A mouse', 'A leopard']], [first_prompt, second_prompt])

    gr.Markdown("# Output Results")
    output = gr.Image(shape=(512, 512))

    greet_btn.click(fn=mingle_prompts, inputs=[first_prompt, second_prompt], outputs=[output])

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
transformers
diffusers
ftfy
torch
gradio
scipy