#!/usr/bin/env python3
"""
Utility functions for the application
Author: Shilpaj Bhalerao
Date: Feb 26, 2025
"""
import gc
import os
import sys

import torch
from PIL import Image, ImageDraw, ImageFont

# Disable HF transfer to avoid download issues
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"

# Create a monkey patch for the cached_download function.
# This is needed because newer versions of huggingface_hub
# removed cached_download but diffusers still tries to import it.
def apply_huggingface_patch():
    import huggingface_hub

    # Check if cached_download is already available
    if hasattr(huggingface_hub, 'cached_download'):
        return  # No need to patch

    # Create a wrapper around hf_hub_download to mimic the old cached_download
    def cached_download(*args, **kwargs):
        # Forward to the new function with the same arguments
        return huggingface_hub.hf_hub_download(*args, **kwargs)

    # Add the function to the huggingface_hub module
    setattr(huggingface_hub, 'cached_download', cached_download)

    # Make sure diffusers.utils.dynamic_modules_utils sees the patched module
    if 'diffusers.utils.dynamic_modules_utils' in sys.modules:
        del sys.modules['diffusers.utils.dynamic_modules_utils']


def load_models(device="cuda"):
    """
    Load the necessary models for stable diffusion
    :param device: (str) Device to load models on ('cuda', 'mps', or 'cpu')
    :return: (tuple) (vae, tokenizer, text_encoder, unet, scheduler, pipe)
    """
    # Apply the patch before importing diffusers
    apply_huggingface_patch()

    # Now we can safely import from diffusers
    from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel, StableDiffusionPipeline
    from transformers import CLIPTokenizer, CLIPTextModel

    # Fall back to MPS or CPU if CUDA is unavailable
    if device == "cuda" and not torch.cuda.is_available():
        device = "mps" if torch.backends.mps.is_available() else "cpu"
    if device == "mps":
        os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
    print(f"Loading models on {device}...")

    # Load the autoencoder model which will be used to decode the latents into image space
    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_safetensors=False)

    # Load the tokenizer and text encoder to tokenize and encode the text
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # The UNet model for generating the latents
    unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", use_safetensors=False)

    # The noise scheduler
    scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

    # Load the full pipeline for concept loading
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        use_safetensors=False
    )

    # Move models to device
    vae = vae.to(device)
    text_encoder = text_encoder.to(device)
    unet = unet.to(device)
    pipe = pipe.to(device)
    return vae, tokenizer, text_encoder, unet, scheduler, pipe
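

# Example usage sketch (not executed on import): how the tuple returned by
# load_models() is typically unpacked and used to encode a prompt. The function
# name, prompt string, and device choice below are illustrative placeholders,
# not values used elsewhere in this module.
def _example_load_and_encode_prompt():
    """Hedged sketch: load the models and encode a sample prompt with CLIP."""
    vae, tokenizer, text_encoder, unet, scheduler, pipe = load_models(device="cpu")
    tokens = tokenizer(
        "a photograph of an astronaut riding a horse",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    )
    with torch.no_grad():
        text_embeddings = text_encoder(tokens.input_ids)[0]
    # For CLIP ViT-L/14 this should be (1, 77, 768)
    return text_embeddings.shape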


def clear_gpu_memory():
    """
    Clear GPU memory cache
    """
    # Only touch the CUDA allocator when CUDA is actually available
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


def set_timesteps(scheduler, num_inference_steps):
    """
    Set timesteps for the scheduler with MPS compatibility fix
    :param scheduler: (Scheduler) Scheduler to set timesteps for
    :param num_inference_steps: (int) Number of inference steps
    """
    scheduler.set_timesteps(num_inference_steps)
    # Cast to float32 because MPS does not support float64 timesteps
    scheduler.timesteps = scheduler.timesteps.to(torch.float32)


def pil_to_latent(input_im, vae, device):
    """
    Convert the image to latents
    :param input_im: (PIL.Image) Input PIL image
    :param vae: (VAE) VAE model
    :param device: (str) Device to run on
    :return: (torch.Tensor) Latents from VAE's encoder
    """
    from torchvision import transforms as tfms

    # Single image -> single latent in a batch (so size 1, 4, 64, 64)
    with torch.no_grad():
        latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(device) * 2 - 1)  # Scale from [0, 1] to [-1, 1]
    return 0.18215 * latent.latent_dist.sample()


def latents_to_pil(latents, vae):
    """
    Convert the latents to images
    :param latents: (torch.Tensor) Latent tensor
    :param vae: (VAE) VAE model
    :return: (list) PIL images
    """
    # Batch of latents -> list of images
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
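

# Example usage sketch (not executed on import): round-trip an image through the
# VAE using pil_to_latent() and latents_to_pil(). The function name and the
# "example.png" file name are illustrative placeholders.
def _example_vae_roundtrip(vae, device="cpu"):
    """Hedged sketch: encode a 512x512 PIL image to latents and decode it back."""
    input_im = Image.open("example.png").convert("RGB").resize((512, 512))
    latents = pil_to_latent(input_im, vae, device)   # expected shape (1, 4, 64, 64)
    decoded = latents_to_pil(latents, vae)           # list with one PIL image
    return decoded[0]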


def image_grid(imgs, rows, cols, labels=None):
    """
    Create a grid of images with optional labels.
    :param imgs: (list) List of PIL images to be arranged in a grid
    :param rows: (int) Number of rows in the grid
    :param cols: (int) Number of columns in the grid
    :param labels: (list, optional) List of label strings for each image
    :return: (PIL.Image) A single image with all input images arranged in a grid and labeled
    """
    assert len(imgs) == rows * cols, f"Number of images ({len(imgs)}) must equal rows*cols ({rows*cols})"
    w, h = imgs[0].size

    # Add padding at the bottom of the grid for labels if they exist
    label_height = 30 if labels else 0
    grid = Image.new('RGB', size=(cols * w, rows * h + label_height))

    # Paste images
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))

    # Add labels if provided
    if labels:
        assert len(labels) == len(imgs), "Number of labels must match number of images"
        draw = ImageDraw.Draw(grid)
        # Try to use a standard font, fall back to default if not available
        try:
            font = ImageFont.truetype("arial.ttf", 14)
        except IOError:
            font = ImageFont.load_default()
        for i, label in enumerate(labels):
            # Position text near the bottom of each image
            x = (i % cols) * w + 10
            y = (i // cols + 1) * h - 5
            # Draw black text with a white outline for visibility
            # White outline (offset the text in each diagonal direction)
            for offset in [(1, 1), (-1, -1), (1, -1), (-1, 1)]:
                draw.text((x + offset[0], y + offset[1]), label, fill=(255, 255, 255), font=font)
            # Main text (black)
            draw.text((x, y), label, fill=(0, 0, 0), font=font)
    return grid
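

# Example usage sketch (not executed on import): arrange four equally sized PIL
# images in a labeled 2x2 grid. The function name and label text are illustrative.
def _example_make_grid(images):
    """Hedged sketch: build a labeled 2x2 grid from four PIL images."""
    labels = [f"seed {i}" for i in range(len(images))]
    return image_grid(images, rows=2, cols=2, labels=labels)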


def vignette_loss(images, vignette_strength=3.0, color_shift=[1.0, 0.5, 0.0]):
    """
    Creates a strong vignette effect (dark corners) and color shift.
    :param images: (torch.Tensor) Batch of images from VAE decoder (range 0-1)
    :param vignette_strength: (float) How strong the darkening effect is (higher = more dramatic)
    :param color_shift: (list) RGB color to shift the center toward [r, g, b]
    :return: (torch.Tensor) Loss value
    """
    batch_size, channels, height, width = images.shape

    # Create coordinate grid centered at 0 with range [-1, 1]
    y = torch.linspace(-1, 1, height).view(-1, 1).repeat(1, width).to(images.device)
    x = torch.linspace(-1, 1, width).view(1, -1).repeat(height, 1).to(images.device)

    # Radius from the center, normalized so the corners map to roughly 1
    radius = torch.sqrt(x.pow(2) + y.pow(2)) / 1.414

    # Vignette mask: dark at edges, bright in center
    vignette = torch.exp(-vignette_strength * radius)

    # Color shift target: shift the center toward the specified color
    color_tensor = torch.tensor(color_shift, dtype=torch.float32, device=images.device)
    center_mask = torch.pow(1.0 - radius, 2.0)  # (H, W); squared to make the transition more dramatic

    # Target image with vignette and color shift
    target = images.clone()

    # Apply vignette (multiply all channels by the vignette mask)
    for c in range(channels):
        target[:, c] = target[:, c] * vignette

    # Apply color shift in the center
    for c in range(channels):
        # Shift toward the target color more in the center, less at the edges;
        # broadcasting (B, H, W) * (H, W) keeps this correct for any batch size
        color_offset = (color_tensor[c] - images[:, c]) * center_mask
        target[:, c] = target[:, c] + color_offset

    # Loss: how different the current image is from our vignette/color-shift target
    return torch.pow(images - target, 2).mean()
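

# Example usage sketch (assumption: loss-guided sampling in the style of the
# classic "blue loss" guidance trick, not confirmed by this module): decode the
# current latents, compute vignette_loss, and nudge the latents along the
# negative gradient. The function name, `sigma`, and `guidance_scale` are
# illustrative; they would come from the sampling loop.
def _example_vignette_guidance_step(latents, vae, sigma, guidance_scale=100.0):
    """Hedged sketch: one gradient step pushing latents toward the vignette target."""
    latents = latents.detach().requires_grad_()
    # Decode with the same 1/0.18215 scaling used in latents_to_pil, keeping the graph
    decoded = vae.decode((1 / 0.18215) * latents).sample
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    loss = vignette_loss(decoded) * guidance_scale
    grad = torch.autograd.grad(loss, latents)[0]
    # Step size scaled by sigma**2, as in loss-guided sampling examples
    return (latents - sigma ** 2 * grad).detach()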


def get_concept_embedding(concept_text, tokenizer, text_encoder, device):
    """
    Generate CLIP embedding for a concept described in text
    :param concept_text: (str) Text description of the concept (e.g., "sketch painting")
    :param tokenizer: (CLIPTokenizer) CLIP tokenizer
    :param text_encoder: (CLIPTextModel) CLIP text encoder
    :param device: (str) Device to run on
    :return: (torch.Tensor) CLIP embedding for the concept
    """
    # Tokenize the concept text
    concept_tokens = tokenizer(
        concept_text,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    ).input_ids.to(device)

    # Generate the embedding using the text encoder
    with torch.no_grad():
        concept_embedding = text_encoder(concept_tokens)[0]
    return concept_embedding
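

# Example usage sketch (not executed on import): compute a concept embedding and
# blend it with a prompt embedding. The function name, prompt text, and the
# 0.7/0.3 mixing weights are illustrative placeholders, not values defined here.
def _example_blend_concept(tokenizer, text_encoder, device="cpu"):
    """Hedged sketch: mix a prompt embedding with a style-concept embedding."""
    prompt_emb = get_concept_embedding("a portrait of a cat", tokenizer, text_encoder, device)
    style_emb = get_concept_embedding("sketch painting", tokenizer, text_encoder, device)
    # Both embeddings share the (1, 77, 768) CLIP shape, so a weighted sum is valid
    return 0.7 * prompt_emb + 0.3 * style_emb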