|
|
""" |
|
|
Image Encoder using pre-trained ResNet50. |
|
|
Implements the visual feature extraction module from the paper. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision.models import resnet50, ResNet50_Weights |
|
|
|
|
|
|
|
|
class ImageEncoder(nn.Module):
    """
    Image encoder built on a ResNet50 backbone with a linear projection head.

    The backbone's classification layer is replaced by ``nn.Identity`` so the
    backbone emits its pooled features, which a linear layer then maps from
    ``config.resnet_out_dim`` to ``config.hidden_dim``.

    Critical: the projection layer is initialized with zeros as per paper, so
    the encoder outputs all-zero features until that layer has been trained.
    """

    def __init__(self, config, pretrained_weights_path: str = None):
        """
        Initialize image encoder.

        Args:
            config: Configuration object. Must provide ``resnet_out_dim``
                (input width of the projection) and ``hidden_dim`` (output
                width of the projection).
            pretrained_weights_path: Path to a ResNet50 state-dict file.
                When falsy, no weights are loaded and the backbone keeps the
                initialization produced by ``resnet50(weights=None)``.
        """
        super().__init__()
        self.config = config

        # Build the architecture only; weights (if any) come from the file below.
        self.resnet = resnet50(weights=None)

        if pretrained_weights_path:
            # map_location="cpu" lets a checkpoint saved from CUDA tensors load
            # on a CPU-only host; load_state_dict copies the values into the
            # existing parameters, so the module's device is not changed.
            # NOTE(security): weights_only=False runs the full pickle
            # unpickler and can execute arbitrary code from the file. Only
            # point this at trusted checkpoints; switch to weights_only=True
            # once the file is confirmed to be a plain tensor state dict.
            state_dict = torch.load(
                pretrained_weights_path,
                map_location="cpu",
                weights_only=False,
            )
            self.resnet.load_state_dict(state_dict)
            print(f"Loaded ResNet50 weights from {pretrained_weights_path}")

        # Drop the classifier (after loading, so the checkpoint's fc.* keys
        # still match) and expose the pooled backbone features instead.
        self.resnet.fc = nn.Identity()

        # Projection head, zero-initialized as required by the paper.
        self.projection = nn.Linear(config.resnet_out_dim, config.hidden_dim)
        nn.init.zeros_(self.projection.weight)
        nn.init.zeros_(self.projection.bias)

        print("Initialized image encoder final layer with zeros")

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through ResNet50 and the projection head.

        Args:
            images: Input images [batch_size, num_masks, C, H, W]. Inputs with
                C == 1 are broadcast to 3 channels before the backbone.

        Returns:
            Visual features [batch_size, num_masks, hidden_dim]

        Raises:
            ValueError: If ``images`` is not a 5-D tensor.
        """
        if images.dim() != 5:
            raise ValueError(
                "images must have shape [batch_size, num_masks, C, H, W], "
                f"got {tuple(images.shape)}"
            )
        batch_size, num_masks, channels, height, width = images.shape

        # Fold the mask axis into the batch axis so the backbone sees a plain
        # 4-D batch. reshape (unlike view) also accepts non-contiguous input,
        # e.g. a tensor produced by transpose/permute.
        flat = images.reshape(batch_size * num_masks, channels, height, width)

        if channels == 1:
            # Broadcast the single channel to the 3 channels ResNet expects;
            # expand creates a view instead of copying the data three times.
            flat = flat.expand(-1, 3, -1, -1)

        features = self.resnet(flat)
        features = self.projection(features)

        # Restore the separate batch and mask axes.
        return features.reshape(batch_size, num_masks, self.config.hidden_dim)
|
|
|