from base64 import b64encode

import torch
import numpy as np
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
from huggingface_hub import notebook_login
import torch.nn.functional as F

# For video display:
from IPython.display import HTML
from matplotlib import pyplot as plt
from pathlib import Path
from PIL import Image
from torch import autocast
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, logging
import os

from device import torch_device, vae, text_encoder, unet, tokenizer, scheduler, token_emb_layer, pos_emb_layer, position_embeddings
# Suppress some unnecessary warnings when loading the CLIPTextModel
logging.set_verbosity_error()
def pil_to_latent(input_im):
    # Single image -> single latent in a batch (so size 1, 4, 64, 64)
    with torch.no_grad():
        latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device) * 2 - 1)  # Note scaling
    return 0.18215 * latent.latent_dist.sample()
def latents_to_pil(latents):
    # Batch of latents -> list of PIL images
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
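
# Hedged usage sketch (not part of the original file): a quick VAE round-trip built on
# the two helpers above. The image path argument and the 512x512 resize are
# illustrative assumptions, chosen so the latent comes out as (1, 4, 64, 64).
def vae_roundtrip_demo(image_path):
    im = Image.open(image_path).convert("RGB").resize((512, 512))
    lat = pil_to_latent(im)        # scaled latent, shape (1, 4, 64, 64)
    return latents_to_pil(lat)[0]  # lossy reconstruction of the input image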
def set_timesteps(scheduler, num_inference_steps):
    scheduler.set_timesteps(num_inference_steps)
    # Cast timesteps to float32 as a workaround for dtype issues on some devices (e.g. MPS)
    scheduler.timesteps = scheduler.timesteps.to(torch.float32)
def orange_loss(image):
    # Sum the red and green channels as a rough proxy for "orange-ness"
    orange_channel = image[:, 0, :, :] + image[:, 1, :, :]
    # Target mean intensity for the combined channel
    target_mean = 0.8  # Replace with your desired mean intensity
    # Loss is the mean absolute difference from the target
    loss = torch.abs(orange_channel - target_mean).mean()
    return loss
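
# Hedged sketch (not part of the original file): one way orange_loss could steer
# sampling, following the common loss-guidance pattern of computing the loss on the
# predicted denoised image and nudging the latents along its gradient before the
# scheduler step. `text_embeddings` (assumed to hold the unconditional and
# conditional embeddings stacked along the batch dim), `guidance_scale` and
# `orange_loss_scale` are illustrative assumptions, not names defined in this file.
def guided_step(latents, t, i, text_embeddings, guidance_scale=7.5, orange_loss_scale=100):
    sigma = scheduler.sigmas[i]

    # Classifier-free guidance: run the UNet on unconditional + conditional inputs
    latent_model_input = torch.cat([latents] * 2)
    latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5)
    with torch.no_grad():
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # Loss guidance: decode the predicted x0, score it with orange_loss,
    # and push the latents down the loss gradient
    latents = latents.detach().requires_grad_()
    latents_x0 = latents - sigma * noise_pred
    denoised_images = vae.decode(latents_x0 / 0.18215).sample / 2 + 0.5  # approx. range (0, 1)
    loss = orange_loss(denoised_images) * orange_loss_scale
    grad = torch.autograd.grad(loss, latents)[0]
    latents = latents.detach() - grad * sigma ** 2

    # Normal scheduler update
    return scheduler.step(noise_pred, t, latents).prev_sample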