import torch
import torch.nn as nn
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
from typing import Tuple, Union
import os
class ImageEncoder(nn.Module):
def __init__(self, clip_model_name: str = "openai/clip-vit-large-patch14-336"):
"""Initialize the image encoder using CLIP.
Args:
clip_model_name: HuggingFace model name for CLIP
"""
super().__init__()
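        # Heavy components (CLIP weights, processor, prediction heads) are built
        # lazily in _ensure_initialized() so that constructing this module stays cheap,
        # e.g. on a ZeroGPU Space where the GPU is only attached once a request arrives.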
# Store model name for lazy loading
self.clip_model_name = clip_model_name
self.clip_model = None
self.processor = None
self.valence_head = None
self.arousal_head = None
self.device = None
self._initialized = False
def _ensure_initialized(self):
"""Lazy initialization of the model components."""
if self._initialized:
return
print(f"Initializing ImageEncoder with {self.clip_model_name}...")
print("Loading CLIP model from local cache (network disabled)...")
# Prefer loading strictly from the local Hugging Face cache that `app.py` populates.
# If the files are genuinely missing (e.g. first run without network), we fall back
# to an online download so the user still gets a working application.
# Determine the cache directory from env – this is set in `app.py`.
hf_cache_dir = os.environ.get("HF_HUB_CACHE", None)
try:
self.clip_model = CLIPModel.from_pretrained(
self.clip_model_name,
cache_dir=hf_cache_dir,
local_files_only=True, # use cache only on the first attempt
)
self.processor = CLIPProcessor.from_pretrained(
self.clip_model_name,
cache_dir=hf_cache_dir,
local_files_only=True,
)
print("CLIP model loaded successfully from local cache")
        except OSError:
            # EnvironmentError is just an alias of OSError in Python 3, so catching
            # OSError alone covers the missing-cache case.
            print(
                "Local cache for CLIP model not found; attempting a one-time online download..."
            )
# Note: this will still respect HF_HUB_CACHE so the files are cached for future runs.
self.clip_model = CLIPModel.from_pretrained(
self.clip_model_name,
cache_dir=hf_cache_dir,
)
self.processor = CLIPProcessor.from_pretrained(
self.clip_model_name,
cache_dir=hf_cache_dir,
)
print("CLIP model downloaded and cached successfully")
print("CLIP model loaded successfully")
# Freeze CLIP parameters
for param in self.clip_model.parameters():
param.requires_grad = False
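        # Only the valence/arousal heads defined below remain trainable.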
# Add projection layers for valence and arousal
hidden_dim = self.clip_model.config.projection_dim
projection_dim = hidden_dim // 2
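        # For openai/clip-vit-large-patch14-336 the CLIP projection_dim is 768,
        # so each head is an MLP mapping 768 -> 384 -> 192 -> 1.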
self.valence_head = nn.Sequential(
nn.Linear(hidden_dim, projection_dim),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(projection_dim, projection_dim // 2),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(projection_dim // 2, 1),
nn.Tanh() # Output between -1 and 1
)
self.arousal_head = nn.Sequential(
nn.Linear(hidden_dim, projection_dim),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(projection_dim, projection_dim // 2),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(projection_dim // 2, 1),
nn.Tanh() # Output between -1 and 1
)
# Move model to GPU if available
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.to(self.device)
print(f"Model moved to device: {self.device}")
self._initialized = True
def forward(self, images: Union[Image.Image, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass to get valence and arousal predictions.
Args:
images: Either PIL images or tensors in CLIP format
Returns:
Tuple of predicted valence and arousal scores
"""
# Ensure model is initialized
self._ensure_initialized()
# Process images if they're PIL images
if isinstance(images, Image.Image):
inputs = self.processor(images=images, return_tensors="pt")
pixel_values = inputs.pixel_values.to(self.device)
else:
pixel_values = images.to(self.device)
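            # Tensors are assumed to already be CLIP-preprocessed pixel values
            # (resized and normalized as the processor would produce them).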
# Get CLIP image features
image_features = self.clip_model.get_image_features(pixel_values)
# Project to valence and arousal scores
valence = self.valence_head(image_features)
arousal = self.arousal_head(image_features)
return valence, arousal
def encode_image(self, image: Image.Image) -> torch.Tensor:
"""Get the raw CLIP image embeddings.
Args:
image: PIL image to encode
Returns:
Image embedding tensor
"""
# Ensure model is initialized
self._ensure_initialized()
inputs = self.processor(images=image, return_tensors="pt")
with torch.no_grad():
image_features = self.clip_model.get_image_features(inputs.pixel_values.to(self.device))
        return image_features
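

# Minimal usage sketch, kept out of the library code behind the __main__ guard.
# Assumptions (not part of the original module): an RGB image at the hypothetical
# path "example.jpg". The heads are randomly initialized here, so the printed
# scores are placeholders until trained weights are loaded.
if __name__ == "__main__":
    encoder = ImageEncoder()
    encoder._ensure_initialized()  # build the lazy components before switching modes
    encoder.eval()  # disable dropout in the heads for deterministic output
    image = Image.open("example.jpg").convert("RGB")  # hypothetical example path
    with torch.no_grad():
        valence, arousal = encoder(image)
    print(f"valence={valence.item():+.3f}, arousal={arousal.item():+.3f}")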