import math
import os
import random
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
from sklearn.decomposition import PCA
from tqdm.auto import tqdm


def extract_clip_features(clip, image, encoder):
    """
    Extracts feature embeddings from an image using either a CLIP or DINOv2 model.

    Args:
        clip (torch.nn.Module): The feature extraction model (either CLIP or DINOv2)
        image (torch.Tensor): Input image tensor normalized according to model requirements
        encoder (str): Type of encoder to use ('dinov2-small' or 'clip')

    Returns:
        torch.Tensor: Feature embeddings extracted from the image

    Note:
        - For DINOv2 models, uses the pooled output features
        - For CLIP models, uses the image features from the vision encoder
        - The input image should already be properly resized and normalized
    """
    # Handle DINOv2 models: take the pooled output of the forward pass
    if 'dino' in encoder:
        features = clip(image)
        features = features.pooler_output
    # Handle CLIP models: use the dedicated image-feature projection
    else:
        features = clip.get_image_features(image)
    return features


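# Usage sketch (not from the original module): loads a DINOv2 and a CLIP encoder from
# Hugging Face transformers and embeds a single image with extract_clip_features.
# The checkpoint names ('facebook/dinov2-small', 'openai/clip-vit-base-patch32'), this
# helper's name, and the example image path are illustrative assumptions only.
def _example_extract_clip_features(image_path='example.png'):
    from PIL import Image
    from transformers import AutoImageProcessor, AutoModel, CLIPModel, CLIPProcessor

    image = Image.open(image_path).convert('RGB')

    # DINOv2 path: the pooled hidden state is used as the embedding
    dino_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-small')
    dino_model = AutoModel.from_pretrained('facebook/dinov2-small')
    dino_pixels = dino_processor(images=image, return_tensors='pt').pixel_values
    dino_feats = extract_clip_features(dino_model, dino_pixels, encoder='dinov2-small')

    # CLIP path: get_image_features projects into CLIP's joint embedding space
    clip_processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
    clip_model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
    clip_pixels = clip_processor(images=image, return_tensors='pt').pixel_values
    clip_feats = extract_clip_features(clip_model, clip_pixels, encoder='clip')

    return dino_feats, clip_feats

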
def compute_clip_pca(
    diverse_prompts: List[str],
    pipe,
    clip_model,
    clip_processor,
    device,
    guidance_scale,
    params,
    total_samples=5000,
    num_pca_components=100,
    batch_size=10,
) -> Tuple[torch.Tensor, List[str], List[str]]:
    """
    Generate images from prompts, extract CLIP features, and compute their principal components.

    Args:
        diverse_prompts: List of prompts to generate images from
        pipe: Diffusion pipeline used to generate the training images
        clip_model: CLIP model used to embed the generated images
        clip_processor: Processor that prepares images for the CLIP model
        device: Device on which to run feature extraction
        guidance_scale: Classifier-free guidance scale passed to the pipeline
        params: Dictionary of generation settings (save path, denoising steps, height, width, ...)
        total_samples: Total number of images to generate
        num_pca_components: Number of principal components to keep
        batch_size: Number of images generated per pipeline call

    Returns:
        Tuple of (CLIP principal components, training prompts, saved image paths)
    """
    # Calculate how many batches we need to reach total_samples
    num_batches = math.ceil(total_samples / batch_size)
    # Randomly sample one prompt per batch (with replacement)
    sampled_prompts_clip = random.choices(diverse_prompts, k=num_batches)

    clip_features_path = f"{params['savepath_training_images']}/clip_principle_directions.pt"
    # Reuse cached principal directions and training metadata if they already exist
    if os.path.exists(clip_features_path):
        df = pd.read_csv(f"{params['savepath_training_images']}/training_data.csv")
        prompts_training = list(df.prompt)
        image_paths = list(df.image_path)
        return torch.load(clip_features_path).to(device), prompts_training, image_paths

    os.makedirs(params['savepath_training_images'], exist_ok=True)

    # Generate images and extract features
    img_idx = 0
    clip_features = []
    image_paths = []
    prompts_training = []
    print('Calculating Semantic PCA')
    for prompt in tqdm(sampled_prompts_clip):
        # Build the generation arguments once; only some pipelines accept max_sequence_length
        pipe_kwargs = dict(
            num_images_per_prompt=batch_size,
            num_inference_steps=params['max_denoising_steps'],
            guidance_scale=guidance_scale,
            height=params['height'],
            width=params['width'],
        )
        if 'max_sequence_length' in params:
            pipe_kwargs['max_sequence_length'] = params['max_sequence_length']
        images = pipe(prompt, **pipe_kwargs).images
        # Process images for the CLIP vision encoder
        clip_inputs = clip_processor(images=images, return_tensors="pt", padding=True)
        pixel_values = clip_inputs['pixel_values'].to(device)

        # Get image embeddings
        with torch.no_grad():
            image_features = clip_model.get_image_features(pixel_values)

        # Normalize embeddings to unit length
        clip_feats = image_features / image_features.norm(dim=1, keepdim=True)
        clip_features.append(clip_feats)

        # Save each generated image and record its prompt and path
        for im in images:
            image_path = f"{params['savepath_training_images']}/{img_idx}.png"
            im.save(image_path)
            image_paths.append(image_path)
            prompts_training.append(prompt)
            img_idx += 1
    clip_features = torch.cat(clip_features)

    # Calculate principal components of the CLIP embeddings
    pca = PCA(n_components=num_pca_components)
    clip_embeds_np = clip_features.float().cpu().numpy()
    pca.fit(clip_embeds_np)
    clip_principles = torch.from_numpy(pca.components_).to(device, dtype=pipe.vae.dtype)

    # Save results so subsequent runs can reuse them
    torch.save(clip_principles, clip_features_path)
    pd.DataFrame({
        'prompt': prompts_training,
        'image_path': image_paths
    }).to_csv(f"{params['savepath_training_images']}/training_data.csv", index=False)

    return clip_principles, prompts_training, image_paths
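

# Usage sketch (not from the original module): drives compute_clip_pca with a diffusers
# Stable Diffusion pipeline and a CLIP model from transformers. The checkpoint names,
# prompt list, and params values below are illustrative assumptions; the real training
# script supplies its own configuration.
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline
    from transformers import CLIPModel, CLIPProcessor

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    pipe = StableDiffusionPipeline.from_pretrained(
        'runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16
    ).to(device)
    clip_model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32').to(device)
    clip_processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')

    # Keys mirror what compute_clip_pca reads; the values are placeholders.
    params = {
        'savepath_training_images': './training_images',
        'max_denoising_steps': 30,
        'height': 512,
        'width': 512,
        # 'max_sequence_length': 256,  # only for pipelines that accept this argument
    }

    clip_principles, prompts_training, image_paths = compute_clip_pca(
        diverse_prompts=['a photo of a dog', 'a watercolor landscape'],
        pipe=pipe,
        clip_model=clip_model,
        clip_processor=clip_processor,
        device=device,
        guidance_scale=7.5,
        params=params,
        total_samples=100,       # small run for the sketch; default is 5000
        num_pca_components=50,   # must not exceed the number of generated samples
        batch_size=10,
    )
    print(clip_principles.shape)  # (num_pca_components, clip_embedding_dim)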