import PIL
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
import torchvision.transforms as T
from diffusers import LMSDiscreteScheduler, DiffusionPipeline

# configurations
torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
height, width = 512, 512
guidance_scale = 8
blue_loss_scale = 200  # weight applied to whichever guidance loss is selected, not only the blue loss
num_inference_steps = 50
elastic_transformer = T.ElasticTransform(alpha=550.0, sigma=5.0)
pretrained_model_name_or_path = "segmind/tiny-sd"

pipe = DiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32
).to(torch_device)

pipe.load_textual_inversion("sd-concepts-library/dreams")
pipe.load_textual_inversion("sd-concepts-library/midjourney-style")
pipe.load_textual_inversion("sd-concepts-library/moebius")
pipe.load_textual_inversion("sd-concepts-library/style-of-marc-allante")
pipe.load_textual_inversion("sd-concepts-library/wlop-style")
concepts_mapping = {
    "Dream": '<meeg>', "Midjourney": '<midjourney-style>',
    "Marc Allante": '<Marc_Allante>', "Moebius": '<moebius>',
    "Wlop": '<wlop-style>'
}
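# concepts_mapping links the human-readable style names shown in the UI to the
# placeholder tokens registered above; generate_art() appends the selected token
# to the user's prompt.
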
def image_loss(images, method='elastic'):
    # elastic loss: penalise how much an elastic distortion changes the image
    if method == 'elastic':
        transformed_imgs = elastic_transformer(images)
        error = torch.abs(transformed_imgs - images).mean()
    # symmetry loss: flip the image along the width and compare to the original
    elif method == "symmetry":
        flipped_image = torch.flip(images, [3])
        error = F.mse_loss(images, flipped_image)
    # saturation loss: penalise the distance to a heavily saturated version
    elif method == 'saturation':
        transformed_imgs = T.functional.adjust_saturation(images, saturation_factor=10)
        error = torch.abs(transformed_imgs - images).mean()
    # blue loss: push the blue channel towards 0.9
    elif method == 'blue':
        error = torch.abs(images[:, 2] - 0.9).mean()  # [:, 2] -> all images in the batch, blue channel only
    else:
        raise ValueError(f"Unknown loss method: {method}")
    return error
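# These losses drive guided generation: in generate_image() below, the selected loss
# is evaluated on the decoded latents every few steps and its gradient with respect
# to the latents is subtracted, nudging the image towards the chosen property
# (e.g. a bluer image, or stronger left/right symmetry).
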
HTML_TEMPLATE = """
<style>
    body {
        background: linear-gradient(135deg, #f5f7fa, #c3cfe2);
    }
    #app-header {
        text-align: center;
        background: rgba(255, 255, 255, 0.8); /* Semi-transparent white */
        padding: 20px;
        border-radius: 10px;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
        position: relative; /* To position the artifacts */
    }
    #app-header h1 {
        color: #4CAF50;
        font-size: 2em;
        margin-bottom: 10px;
    }
    .concept {
        position: relative;
        transition: transform 0.3s;
    }
    .concept:hover {
        transform: scale(1.1);
    }
    .concept img {
        width: 100px;
        border-radius: 10px;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }
    .concept-description {
        position: absolute;
        bottom: -30px;
        left: 50%;
        transform: translateX(-50%);
        background-color: #4CAF50;
        color: white;
        padding: 5px 10px;
        border-radius: 5px;
        opacity: 0;
        transition: opacity 0.3s;
    }
    .concept:hover .concept-description {
        opacity: 1;
    }
    /* Artifacts */
    .artifact {
        position: absolute;
        background: rgba(76, 175, 80, 0.1); /* Semi-transparent green */
        border-radius: 50%; /* Make it circular */
    }
    .artifact.large {
        width: 300px;
        height: 300px;
        top: -50px;
        left: -150px;
    }
    .artifact.medium {
        width: 200px;
        height: 200px;
        bottom: -50px;
        right: -100px;
    }
    .artifact.small {
        width: 100px;
        height: 100px;
        top: 50%;
        left: 50%;
        transform: translate(-50%, -50%);
    }
</style>
<div id="app-header">
    <!-- Artifacts -->
    <div class="artifact large"></div>
    <div class="artifact medium"></div>
    <div class="artifact small"></div>
    <!-- Content -->
    <h1>Art Generator</h1>
    <p>Generate new art in five different styles by providing a prompt.</p>
    <div style="display: flex; justify-content: center; gap: 20px; margin-top: 20px;">
        <div class="concept">
            <img src="https://github.com/Delve-ERAV1/S20/assets/11761529/30ac92f8-fc62-4aab-9221-043865c6fe7c" alt="Midjourney">
            <div class="concept-description">Midjourney Style</div>
        </div>
        <div class="concept">
            <img src="https://github.com/Delve-ERAV1/S20/assets/11761529/54c9a61e-df9f-4054-835b-ec2c6ba5916c" alt="Dreams">
            <div class="concept-description">Dreams Style</div>
        </div>
        <div class="concept">
            <img src="https://github.com/Delve-ERAV1/S20/assets/11761529/2f37e402-15d1-4a74-ba85-bb1566da930e" alt="Moebius">
            <div class="concept-description">Moebius Style</div>
        </div>
        <div class="concept">
            <img src="https://github.com/Delve-ERAV1/S20/assets/11761529/f838e767-ac20-4996-b5be-65c61b365ce0" alt="Allante">
            <div class="concept-description">Hong Kong-born artist inspired by western and eastern influences</div>
        </div>
        <div class="concept">
            <img src="https://github.com/Delve-ERAV1/S20/assets/11761529/9958140a-1b62-4972-83ca-85b023e3863f" alt="Wlop">
            <div class="concept-description">WLOP (born 1987) is known for digital art (NFTs)</div>
        </div>
    </div>
</div>
"""

def get_examples():
    examples = [
        ['A powerful man in dreadlocks', 'Dream', 'Symmetry', 45],
        ['World Peace', 'Marc Allante', 'Saturation', 147],
        ['Storm trooper in the desert, dramatic lighting, high-detail', 'Moebius', 'Elastic', 28],
        ['Delicious Italian pizza on a table, a window in the background overlooking a city skyline', 'Wlop', 'Blue', 50],
    ]
    return examples
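# Each example row above is (prompt, concept, loss method, seed), matching the
# generate_art() signature below.
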
def latents_to_pil(latents):
    # batch of latents -> first image as a PIL Image
    latents = (1 / 0.18215) * latents  # undo the Stable Diffusion VAE latent scaling factor
    with torch.no_grad():
        image = pipe.vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)  # map from [-1, 1] to [0, 1]
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    image = (image * 255).round().astype("uint8")
    return Image.fromarray(image[0])
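# With the 512x512 configuration above, a latent batch of shape (1, 4, 64, 64)
# decodes to a single 512x512 RGB PIL image.
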
def generate_art(prompt, concept, method, seed):
    # append the textual-inversion token for the chosen style to the prompt
    prompt = f"{prompt} in the style of {concepts_mapping[concept]}"
    img_no_loss = latents_to_pil(generate_image(prompt, method, seed))
    img_loss = latents_to_pil(generate_image(prompt, method, seed, loss_apply=True))
    return [img_no_loss, img_loss]

def generate_image(prompt, method, seed, loss_apply=False):
    generator = torch.manual_seed(seed)
    batch_size = 1
    method = method.lower()

    # scheduler
    scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
    scheduler.set_timesteps(num_inference_steps)
    # cast timesteps to float32 (float64 is unsupported on some backends such as MPS)
    scheduler.timesteps = scheduler.timesteps.to(torch.float32)

    # text embeddings of the prompt
    text_input = pipe.tokenizer([prompt], padding='max_length', max_length=pipe.tokenizer.model_max_length, truncation=True, return_tensors="pt")
    input_ids = text_input.input_ids.to(torch_device)
    with torch.no_grad():
        text_embeddings = pipe.text_encoder(input_ids)[0]

    # unconditional (empty-prompt) embeddings for classifier-free guidance
    max_length = text_input.input_ids.shape[-1]
    uncond_input = pipe.tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(torch_device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # random starting latents
    latents = torch.randn(
        (batch_size, pipe.unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(torch.float32)
    latents = latents.to(torch_device)
    latents = latents * scheduler.init_noise_sigma

    # denoising loop
    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        # duplicate the latents so the unconditional and text-conditioned passes run in one batch
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        with torch.no_grad():
            noise_pred = pipe.unet(latent_model_input.to(torch.float32), t, encoder_hidden_states=text_embeddings)["sample"]

        # classifier-free guidance: push the prediction away from the unconditional one
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        if loss_apply and i % 5 == 0:
            latents = latents.detach().requires_grad_()
            # predict the fully denoised image (x0) from the current latents
            latents_x0 = latents - sigma * noise_pred
            # decode the prediction to pixel space, mapped to [0, 1]
            denoised_images = pipe.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
            # score it with the selected guidance loss
            loss = image_loss(denoised_images, method) * blue_loss_scale
            # step the latents down the loss gradient, scaled by sigma**2
            cond_grad = torch.autograd.grad(loss, latents)[0]
            latents = latents.detach() - cond_grad * sigma**2

        latents = scheduler.step(noise_pred, t, latents).prev_sample

    return latents
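
# Usage sketch (hypothetical prompt and seed; the app's UI normally supplies these values):
#   baseline, guided = generate_art("A lighthouse in a storm", "Moebius", "Elastic", seed=42)
#   baseline.save("no_loss.png")
#   guided.save("with_loss.png")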