from typing import List, Tuple

import math
import os
import random

import pandas as pd
import torch
from tqdm.auto import tqdm
from sklearn.decomposition import PCA


def extract_clip_features(clip, image, encoder):
    """
    Extract feature embeddings from an image using either a CLIP or a DINOv2 model.

    Args:
        clip (torch.nn.Module): The feature extraction model (CLIP or DINOv2).
        image (torch.Tensor): Input image tensor, already resized and normalized
            according to the model's requirements.
        encoder (str): Encoder identifier; any value containing 'dino'
            (e.g. 'dinov2-small') selects the DINOv2 path, otherwise CLIP is assumed.

    Returns:
        torch.Tensor: Feature embeddings extracted from the image.

    Note:
        - For DINOv2 models, the pooled output features are used.
        - For CLIP models, the image features from the vision encoder are used.
    """
    if 'dino' in encoder:
        # DINOv2 models: use the pooled output of the backbone
        features = clip(image).pooler_output
    else:
        # CLIP models: use the image features from the vision encoder
        features = clip.get_image_features(image)

    return features

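# --- Illustrative usage sketch (not called anywhere in this module) ----------
# The helper below shows how `extract_clip_features` might be exercised with a
# DINOv2 backbone and a CLIP model loaded from `transformers`. The checkpoint
# names and the dummy input tensor are assumptions for demonstration only, not
# values required by this file.
def _example_extract_features():
    from transformers import AutoModel, CLIPModel

    # Stand-in for a batch of images already resized/normalized for the encoder
    dummy = torch.rand(1, 3, 224, 224)

    dinov2 = AutoModel.from_pretrained("facebook/dinov2-small").eval()
    with torch.no_grad():
        dino_feats = extract_clip_features(dinov2, dummy, encoder="dinov2-small")

    clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
    with torch.no_grad():
        clip_feats = extract_clip_features(clip, dummy, encoder="clip")

    return dino_feats.shape, clip_feats.shape
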
@torch.no_grad()
def compute_clip_pca(
    diverse_prompts: List[str],
    pipe,
    clip_model,
    clip_processor,
    device,
    guidance_scale,
    params,
    total_samples=5000,
    num_pca_components=100,
    batch_size=10,
) -> Tuple[torch.Tensor, List[str], List[str]]:
    """
    Extract CLIP features from generated images based on prompts.
    
    Args:
        diverse_prompts: List of prompts to generate images from
        model_components: Various model components needed for generation
        args: Training arguments
        
    Returns:
        Tensor of CLIP principle components
    """
    
    
    # Calculate how many total batches we need
    num_batches = math.ceil(total_samples / batch_size)
    # Randomly sample prompts (with replacement if needed)
    sampled_prompts_clip = random.choices(diverse_prompts, k=num_batches)
    
    clip_features_path = f"{params['savepath_training_images']}/clip_principle_directions.pt"
    
    if os.path.exists(clip_features_path):
        df = pd.read_csv(f"{params['savepath_training_images']}/training_data.csv")
        prompts_training = list(df.prompt)
        image_paths = list(df.image_path)
        return torch.load(clip_features_path).to(device), prompts_training, image_paths
    
    os.makedirs(params['savepath_training_images'], exist_ok=True)
    
    # Generate images and extract features
    img_idx = 0
    clip_features = []
    image_paths = []
    prompts_training = []
    print('Calculating Semantic PCA')
    
    for prompt in tqdm(sampled_prompts_clip):
        # Pass max_sequence_length only to pipelines that accept it
        if 'max_sequence_length' in params:
            images = pipe(
                prompt,
                num_images_per_prompt=batch_size,
                num_inference_steps=params['max_denoising_steps'],
                guidance_scale=guidance_scale,
                max_sequence_length=params['max_sequence_length'],
                height=params['height'],
                width=params['width'],
            ).images
        else:
            images = pipe(
                prompt,
                num_images_per_prompt=batch_size,
                num_inference_steps=params['max_denoising_steps'],
                guidance_scale=guidance_scale,
                height=params['height'],
                width=params['width'],
            ).images
        
        # Preprocess the generated images for the CLIP image encoder
        clip_inputs = clip_processor(images=images, return_tensors="pt", padding=True)
        pixel_values = clip_inputs['pixel_values'].to(device)

        # Get image embeddings (gradients are already disabled by the decorator)
        image_features = clip_model.get_image_features(pixel_values)

        # L2-normalize the embeddings
        clip_feats = image_features / image_features.norm(dim=1, keepdim=True)
        clip_features.append(clip_feats)

        # Save the generated images and record their prompts and paths
        for im in images:
            image_path = f"{params['savepath_training_images']}/{img_idx}.png"
            im.save(image_path)
            image_paths.append(image_path)
            prompts_training.append(prompt)
            img_idx += 1

    clip_features = torch.cat(clip_features)

    
    # Compute the principal components of the CLIP embeddings
    pca = PCA(n_components=num_pca_components)
    clip_embeds_np = clip_features.float().cpu().numpy()
    pca.fit(clip_embeds_np)
    clip_principles = torch.from_numpy(pca.components_).to(device, dtype=pipe.vae.dtype)
    
    # Save results
    torch.save(clip_principles, clip_features_path)
    pd.DataFrame({
        'prompt': prompts_training,
        'image_path': image_paths
    }).to_csv(f"{params['savepath_training_images']}/training_data.csv", index=False)
    
    return clip_principles, prompts_training, image_paths
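

if __name__ == "__main__":
    # Minimal end-to-end sketch, assuming a Stable Diffusion pipeline and an
    # OpenAI CLIP checkpoint. The model IDs, prompts, and `params` values below
    # are illustrative assumptions, not settings prescribed by this module.
    from diffusers import StableDiffusionPipeline
    from transformers import CLIPModel, CLIPProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"

    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

    params = {
        'savepath_training_images': 'training_images',
        'max_denoising_steps': 30,
        'height': 512,
        'width': 512,
    }
    prompts = ["a photo of a dog", "a watercolor painting of a city at night"]

    principal_dirs, train_prompts, train_paths = compute_clip_pca(
        diverse_prompts=prompts,
        pipe=pipe,
        clip_model=clip_model,
        clip_processor=clip_processor,
        device=device,
        guidance_scale=7.5,
        params=params,
        total_samples=200,       # small numbers for a quick smoke test
        num_pca_components=50,   # must not exceed the number of samples
        batch_size=4,
    )
    print(principal_dirs.shape, len(train_prompts), len(train_paths))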