Spaces:
Sleeping
Sleeping
Upload 6 files
Browse files- config.py +20 -0
- diffusion_loss.py +24 -0
- image_generator.py +73 -0
- model.py +16 -0
- prediction.py +24 -0
- utils.py +24 -0
config.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
DEVICE = "mps"
|
4 |
+
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
5 |
+
HEIGHT = 512
|
6 |
+
WIDTH = 512
|
7 |
+
GUIDANCE_SCALE = 8
|
8 |
+
LOSS_SCALE = 200
|
9 |
+
NUM_INFERENCE_STEPS = 50
|
10 |
+
BATCH_SIZE = 1
|
11 |
+
|
12 |
+
SEEDS = [2000,2000,500,600,100]
|
13 |
+
STABLE_DIFFUSION_MODEL = "CompVis/stable-diffusion-v1-4"
|
14 |
+
STABLE_DIFUSION_CONCEPTS = ['<meeg>', '<midjourney-style>', '<moebius>', ' <Marc_Allante>', '<wlop-style>']
|
15 |
+
|
16 |
+
#LMS DSCRETE SCHEDULER
|
17 |
+
BETA_START = 0.00085
|
18 |
+
BETA_END = 0.012
|
19 |
+
BETA_SCHEDULE = "scaled_linear"
|
20 |
+
NUM_TRAIN_TIMESTEPS = 1000
|
diffusion_loss.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision.transforms as T
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
def blue_channel(images):
|
6 |
+
error = torch.abs(images[:,2] - 0.9).mean()
|
7 |
+
return error
|
8 |
+
|
9 |
+
def elastic_transform(images):
|
10 |
+
elastic_transformer = T.ElasticTransform(alpha=550.0,sigma=5.0)
|
11 |
+
transformed_imgs = elastic_transformer(images)
|
12 |
+
error = torch.abs(transformed_imgs - images).mean()
|
13 |
+
return error
|
14 |
+
|
15 |
+
def symmetry(images):
|
16 |
+
flipped_image = torch.flip(images, [3])
|
17 |
+
error = F.mse_loss(images, flipped_image)
|
18 |
+
print("Loss Calculated for the Symmetry : ", error)
|
19 |
+
return error
|
20 |
+
|
21 |
+
def saturation(images):
|
22 |
+
transformed_imgs = T.functional.adjust_saturation(images,saturation_factor = 10)
|
23 |
+
error = torch.abs(transformed_imgs - images).mean()
|
24 |
+
return error
|
image_generator.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from tqdm.auto import tqdm
|
3 |
+
from diffusers import LMSDiscreteScheduler
|
4 |
+
|
5 |
+
import config
|
6 |
+
|
7 |
+
def construct_text_embeddings(pipe, prompt):
|
8 |
+
text_input = pipe.tokenizer(prompt, padding='max_length',
|
9 |
+
max_length = pipe.tokenizer.model_max_length, truncation= True,
|
10 |
+
return_tensors="pt")
|
11 |
+
uncond_input = pipe.tokenizer([""] * config.BATCH_SIZE, padding="max_length",
|
12 |
+
max_length= text_input.input_ids.shape[-1],
|
13 |
+
return_tensors="pt")
|
14 |
+
with torch.no_grad():
|
15 |
+
text_input_embeddings = pipe.text_encoder(text_input.input_ids.to(config.DEVICE))[0]
|
16 |
+
with torch.no_grad():
|
17 |
+
uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(config.DEVICE))[0]
|
18 |
+
|
19 |
+
text_embeddings = torch.cat([uncond_embeddings, text_input_embeddings])
|
20 |
+
return text_embeddings
|
21 |
+
|
22 |
+
def initialize_latent(seed_number, pipe, scheduler):
|
23 |
+
generator = torch.manual_seed(seed_number)
|
24 |
+
latent = torch.randn((config.BATCH_SIZE, pipe.unet.config.in_channels,
|
25 |
+
config.HEIGHT//8, config.WIDTH//8),
|
26 |
+
generator = generator).to(torch.float16)
|
27 |
+
latent = latent.to(config.DEVICE)
|
28 |
+
latent = latent * scheduler.init_noise_sigma
|
29 |
+
return latent
|
30 |
+
|
31 |
+
def run_prediction(pipe, text_embeddings, scheduler, latent, loss_function=None):
|
32 |
+
for i, t in tqdm(enumerate(scheduler.timesteps), total = len(scheduler.timesteps)):
|
33 |
+
latent_model_input = torch.cat([latent] * 2)
|
34 |
+
sigma = scheduler.sigmas[i]
|
35 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
36 |
+
|
37 |
+
with torch.no_grad():
|
38 |
+
noise_pred = pipe.unet(latent_model_input.to(torch.float16), t, encoder_hidden_states=text_embeddings)["sample"]
|
39 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
40 |
+
noise_pred = noise_pred_uncond + config.GUIDANCE_SCALE * (noise_pred_text - noise_pred_uncond)
|
41 |
+
|
42 |
+
if loss_function and i%5 == 0:
|
43 |
+
latent = latent.detach().requires_grad_()
|
44 |
+
latent_x0 = latent - sigma * noise_pred
|
45 |
+
|
46 |
+
denoised_images = pipe.vae.decode((1/ 0.18215) * latent_x0).sample / 2 + 0.5 # range(0,1)
|
47 |
+
|
48 |
+
loss = loss_function(denoised_images) * config.LOSS_SCALE
|
49 |
+
print(f"loss {loss}")
|
50 |
+
|
51 |
+
cond_grad = torch.autograd.grad(loss, latent)[0]
|
52 |
+
latent = latent.detach() - cond_grad * sigma**2
|
53 |
+
|
54 |
+
latent = scheduler.step(noise_pred,t, latent).prev_sample
|
55 |
+
|
56 |
+
return latent
|
57 |
+
|
58 |
+
def generate_images(pipe, seed_number, prompt, loss_function=None):
|
59 |
+
|
60 |
+
scheduler = LMSDiscreteScheduler(beta_start = 0.00085,
|
61 |
+
beta_end = 0.012,
|
62 |
+
beta_schedule = "scaled_linear",
|
63 |
+
num_train_timesteps = 1000)
|
64 |
+
scheduler.set_timesteps(config.NUM_INFERENCE_STEPS)
|
65 |
+
scheduler.timesteps = scheduler.timesteps.to(torch.float32)
|
66 |
+
|
67 |
+
text_embeddings = construct_text_embeddings(pipe=pipe, prompt=prompt)
|
68 |
+
latent = initialize_latent(seed_number=seed_number, pipe=pipe, scheduler=scheduler)
|
69 |
+
latent = run_prediction(pipe=pipe, text_embeddings=text_embeddings,
|
70 |
+
scheduler=scheduler, latent=latent,
|
71 |
+
loss_function=loss_function)
|
72 |
+
|
73 |
+
return latent
|
model.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from diffusers import DiffusionPipeline
|
3 |
+
import config
|
4 |
+
|
5 |
+
def initialize_diffusion_model():
|
6 |
+
pretrained_model_name_or_path = config.STABLE_DIFFUSION_MODEL
|
7 |
+
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path,
|
8 |
+
torch_dtype=torch.float16).to(config.DEVICE)
|
9 |
+
|
10 |
+
pipe.load_textual_inversion("sd-concepts-library/dreams")
|
11 |
+
pipe.load_textual_inversion("sd-concepts-library/midjourney-style")
|
12 |
+
pipe.load_textual_inversion("sd-concepts-library/moebius")
|
13 |
+
pipe.load_textual_inversion("sd-concepts-library/style-of-marc-allante")
|
14 |
+
pipe.load_textual_inversion("sd-concepts-library/wlop-style")
|
15 |
+
|
16 |
+
return pipe
|
prediction.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import gc
|
3 |
+
|
4 |
+
import utils
|
5 |
+
import model
|
6 |
+
import config
|
7 |
+
import image_generator as generator
|
8 |
+
|
9 |
+
def predict(prompt, pipe, loss_function=None):
|
10 |
+
latents = []
|
11 |
+
|
12 |
+
for seed_number, sd_concept in zip(config.SEEDS, config.STABLE_DIFUSION_CONCEPTS):
|
13 |
+
torch.mps.empty_cache()
|
14 |
+
gc.collect()
|
15 |
+
torch.mps.empty_cache()
|
16 |
+
|
17 |
+
prompt = [f'{prompt} {sd_concept}']
|
18 |
+
latent = generator.generate_images(pipe=pipe, seed_number=seed_number, prompt=prompt, loss_function=loss_function)
|
19 |
+
latents.append(latent)
|
20 |
+
|
21 |
+
latents = torch.vstack(latents)
|
22 |
+
images = utils.convert_latents_to_pil_images(pipe=pipe, latents=latents)
|
23 |
+
grid = utils.populate_image_grid(images, 1, len(latents))
|
24 |
+
return grid
|
utils.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
from diffusers import LMSDiscreteScheduler
|
4 |
+
|
5 |
+
import config
|
6 |
+
|
7 |
+
def convert_latents_to_pil_images(pipe, latents):
|
8 |
+
latents = (1 / 0.18215) * latents
|
9 |
+
with torch.no_grad():
|
10 |
+
image = pipe.vae.decode(latents).sample
|
11 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
12 |
+
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
13 |
+
images = (image * 255).round().astype("uint8")
|
14 |
+
pil_images = [Image.fromarray(image) for image in images]
|
15 |
+
return pil_images
|
16 |
+
|
17 |
+
def populate_image_grid(imgs, rows, cols):
|
18 |
+
assert len(imgs) == rows*cols
|
19 |
+
w, h = imgs[0].size
|
20 |
+
grid = Image.new('RGB', size=(cols*w, rows*h))
|
21 |
+
|
22 |
+
for i, img in enumerate(imgs):
|
23 |
+
grid.paste(img, box=(i%cols*w, i//cols*h))
|
24 |
+
return grid
|