Gosula committed on
Commit 607bd68 · 1 Parent(s): 40026ad

Create utils.py

Files changed (1)
  1. utils.py +69 -0
utils.py ADDED
@@ -0,0 +1,69 @@
+ from base64 import b64encode
+
+ import torch
+ import numpy as np
+ from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
+ from huggingface_hub import notebook_login
+ import torch.nn.functional as F
+ # For video display:
+ from IPython.display import HTML
+ from matplotlib import pyplot as plt
+ from pathlib import Path
+ from PIL import Image
+ from torch import autocast
+ from torchvision import transforms as tfms
+ from tqdm.auto import tqdm
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
+ import os
+ from device import torch_device, vae, text_encoder, unet, tokenizer, scheduler, token_emb_layer, pos_emb_layer, position_embeddings
+
+
+ # Suppress some unnecessary warnings when loading the CLIPTextModel
+ logging.set_verbosity_error()
+
+
+ def pil_to_latent(input_im):
+     # Single image -> single latent in a batch (so shape 1, 4, 64, 64)
+     with torch.no_grad():
+         # Scale the image from [0, 1] to [-1, 1] before encoding
+         latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device) * 2 - 1)
+     return 0.18215 * latent.latent_dist.sample()
+
+
+ def latents_to_pil(latents):
+     # Batch of latents -> list of PIL images
+     latents = (1 / 0.18215) * latents
+     with torch.no_grad():
+         image = vae.decode(latents).sample
+     image = (image / 2 + 0.5).clamp(0, 1)
+     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+     images = (image * 255).round().astype("uint8")
+     pil_images = [Image.fromarray(image) for image in images]
+     return pil_images
+
+
+ def set_timesteps(scheduler, num_inference_steps):
+     scheduler.set_timesteps(num_inference_steps)
+     # Ensure float32 timesteps (some schedulers return float64, which breaks on MPS)
+     scheduler.timesteps = scheduler.timesteps.to(torch.float32)
+
+
+ def orange_loss(image):
+     # Approximate an "orange" channel by summing the red and green channels
+     orange_channel = image[:, 0, :, :] + image[:, 1, :, :]
+
+     # Target mean intensity for the orange channel
+     target_mean = 0.8  # Replace with your desired mean intensity
+
+     # Loss is the mean absolute difference from the target
+     loss = torch.abs(orange_channel - target_mean).mean()
+
+     return loss
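
Below is a minimal usage sketch (not part of the commit) showing how these helpers fit together, assuming the repo's `device` module has already loaded the models onto `torch_device`; the file name `sample.png` is a placeholder.

from PIL import Image
from torchvision import transforms as tfms

from utils import pil_to_latent, latents_to_pil, orange_loss

im = Image.open("sample.png").convert("RGB").resize((512, 512))  # placeholder input image
latents = pil_to_latent(im)            # (1, 4, 64, 64) scaled latents on torch_device
decoded = latents_to_pil(latents)[0]   # round-trip the latents back to a PIL image

# orange_loss expects a (B, C, H, W) tensor with values in [0, 1]
img_t = tfms.ToTensor()(decoded).unsqueeze(0)
print(orange_loss(img_t).item())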