import torch
from PIL import Image
from torchvision import transforms as tfms
from transformers import logging

# Shared Stable Diffusion components, loaded once in device.py
from device import torch_device, vae, text_encoder, unet, tokenizer, scheduler, token_emb_layer, pos_emb_layer, position_embeddings

# Suppress some unnecessary warnings when loading the CLIPTextModel
logging.set_verbosity_error()

def pil_to_latent(input_im):
    # Single PIL image -> single latent in a batch (shape 1, 4, 64, 64)
    with torch.no_grad():
        # Scale pixel values from [0, 1] to [-1, 1] before encoding
        latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device) * 2 - 1)
    # 0.18215 is the scaling factor Stable Diffusion applies to VAE latents
    return 0.18215 * latent.latent_dist.sample()

def latents_to_pil(latents):
    # Batch of latents -> list of PIL images
    latents = (1 / 0.18215) * latents  # undo the SD latent scaling
    with torch.no_grad():
        image = vae.decode(latents).sample
    # Map from [-1, 1] back to [0, 1], then to uint8 HWC arrays
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
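
# A minimal round-trip sketch (hypothetical usage; the file "input.png" and
# the 512x512 size are assumptions, not part of this module):
#
#   im = Image.open("input.png").convert("RGB").resize((512, 512))
#   latents = pil_to_latent(im)         # tensor of shape (1, 4, 64, 64)
#   recon = latents_to_pil(latents)[0]  # decoded back to a PIL image
#   recon.save("roundtrip.png")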

def set_timesteps(scheduler, num_inference_steps):
    scheduler.set_timesteps(num_inference_steps)
    # Cast to float32 as a workaround for an MPS compatibility issue with float64 timesteps
    scheduler.timesteps = scheduler.timesteps.to(torch.float32)
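
# Hypothetical usage in a sampling loop (the loop body is elided):
#
#   set_timesteps(scheduler, 30)
#   for t in scheduler.timesteps:
#       ...  # predict noise with the UNet, then step the scheduler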

def orange_loss(image):
    # Approximate "orange" intensity as the sum of the red and green channels
    orange_channel = image[:, 0, :, :] + image[:, 1, :, :]
    # Target mean intensity for the combined channel
    target_mean = 0.8
    # Mean absolute deviation from the target
    loss = torch.abs(orange_channel - target_mean).mean()
    return loss
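
# A sketch of how this loss might steer sampling, following the common
# loss-guidance pattern (an assumption, not confirmed by this file;
# `latents`, `noise_pred`, `sigma`, and `guidance_scale` would come from the
# surrounding denoising loop and are not defined here):
#
#   latents = latents.detach().requires_grad_()
#   latents_x0 = latents - sigma * noise_pred               # predicted clean latents
#   images = vae.decode(latents_x0 / 0.18215).sample / 2 + 0.5
#   loss = orange_loss(images) * guidance_scale
#   grad = torch.autograd.grad(loss, latents)[0]
#   latents = latents.detach() - grad * sigma**2            # nudge against the loss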