import torch
import torch.nn as nn
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
from typing import Tuple, Union
import os

class ImageEncoder(nn.Module):
    def __init__(self, clip_model_name: str = "openai/clip-vit-large-patch14-336"):
        """Initialize the image encoder using CLIP.
        
        Args:
            clip_model_name: HuggingFace model name for CLIP
        """
        super().__init__()
        
        # Store model name for lazy loading
        self.clip_model_name = clip_model_name
        self.clip_model = None
        self.processor = None
        self.valence_head = None
        self.arousal_head = None
        self.device = None
        self._initialized = False
    
    def _ensure_initialized(self):
        """Lazy initialization of the model components."""
        if self._initialized:
            return
            
        print(f"Initializing ImageEncoder with {self.clip_model_name}...")
        print("Loading CLIP model from local cache (network disabled)...")

        # Prefer loading strictly from the local Hugging Face cache that `app.py` populates.
        # If the files are genuinely missing (e.g. first run without network), we fall back
        # to an online download so the user still gets a working application.

        # Determine the cache directory from env – this is set in `app.py`.
        hf_cache_dir = os.environ.get("HF_HUB_CACHE", None)

        try:
            self.clip_model = CLIPModel.from_pretrained(
                self.clip_model_name,
                cache_dir=hf_cache_dir,
                local_files_only=True,  # use cache only on the first attempt
            )
            self.processor = CLIPProcessor.from_pretrained(
                self.clip_model_name,
                cache_dir=hf_cache_dir,
                local_files_only=True,
            )
            print("CLIP model loaded successfully from local cache")
        except OSError as cache_err:
            # EnvironmentError is just an alias of OSError in Python 3, so OSError covers both.
            print(
                f"Local cache for CLIP model not found ({cache_err}) – attempting a one-time online download..."
            )
            # Note: this will still respect HF_HUB_CACHE so the files are cached for future runs.
            self.clip_model = CLIPModel.from_pretrained(
                self.clip_model_name,
                cache_dir=hf_cache_dir,
            )
            self.processor = CLIPProcessor.from_pretrained(
                self.clip_model_name,
                cache_dir=hf_cache_dir,
            )
            print("CLIP model downloaded and cached successfully")
        
        print("CLIP model loaded successfully")
        
        # Freeze CLIP parameters
        for param in self.clip_model.parameters():
            param.requires_grad = False
            
        # Add projection layers for valence and arousal
        hidden_dim = self.clip_model.config.projection_dim
        projection_dim = hidden_dim // 2
        
        self.valence_head = nn.Sequential(
            nn.Linear(hidden_dim, projection_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim, projection_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim // 2, 1),
            nn.Tanh()  # Output between -1 and 1
        )
        
        self.arousal_head = nn.Sequential(
            nn.Linear(hidden_dim, projection_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim, projection_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim // 2, 1),
            nn.Tanh()  # Output between -1 and 1
        )
        
        # Move model to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)
        
        print(f"Model moved to device: {self.device}")
        self._initialized = True

    def forward(self, images: Union[Image.Image, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass to get valence and arousal predictions.
        
        Args:
            images: A single PIL image, or a tensor of CLIP-preprocessed pixel values
            
        Returns:
            Tuple of predicted valence and arousal scores
        """
        # Ensure model is initialized
        self._ensure_initialized()
        
        # Process images if they're PIL images
        if isinstance(images, Image.Image):
            inputs = self.processor(images=images, return_tensors="pt")
            pixel_values = inputs.pixel_values.to(self.device)
        else:
            pixel_values = images.to(self.device)
            
        # Get CLIP image features
        image_features = self.clip_model.get_image_features(pixel_values)
        
        # Project to valence and arousal scores
        valence = self.valence_head(image_features)
        arousal = self.arousal_head(image_features)
        
        return valence, arousal
    
    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Get the raw CLIP image embeddings.
        
        Args:
            image: PIL image to encode
            
        Returns:
            Image embedding tensor
        """
        # Ensure model is initialized
        self._ensure_initialized()
        
        inputs = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            image_features = self.clip_model.get_image_features(inputs.pixel_values.to(self.device))
        return image_features
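

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original module).
# Shows how the encoder might be exercised end to end; "example.jpg" is a
# placeholder path, and the valence/arousal heads are randomly initialized
# here unless trained weights are loaded into them separately.
if __name__ == "__main__":
    encoder = ImageEncoder()

    # Any RGB image works; the path below is hypothetical.
    image = Image.open("example.jpg").convert("RGB")

    with torch.no_grad():
        valence, arousal = encoder(image)

    # Each output has shape (1, 1) with values in [-1, 1] thanks to the Tanh heads.
    print(f"valence={valence.item():.3f}, arousal={arousal.item():.3f}")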