InfoNCE — vit_base_patch16_224 / DataComp-large
Checkpoint for InfoNCE, a contrastive vision-language pretraining baseline using the symmetric InfoNCE (CLIP-style) objective over a BN-MLP projection head. Trained on DataComp-large for 200,000 steps with batch size 4,096.
Model summary
| Property | Value |
|---|---|
| Method | InfoNCE |
| Vision encoder | vit_base_patch16_224 (timm) |
| Text encoder | GPT-2 (12L / 12H / 768D) |
| Embedding dim | 512 |
| Projection head | Linear→BN→GELU→Linear (width 2048) |
| Training objective | Symmetric InfoNCE (CLIP-style) contrastive loss |
| Training data | DataComp-large |
| Training steps | 200,000 |
Usage
import torch
import timm
from transformers import GPT2Config, GPT2Model, AutoTokenizer
from safetensors.torch import load_file
from torchvision.ops import MLP
import torch.nn as nn
# Width of the features fed into both projection heads (and GPT-2's n_embd).
HIDDEN = 768
# Output width of both projection heads.
EMBED = 512

# ── Vision encoder ───────────────────────────────────────────────────────────
# Built without pretrained weights; the checkpoint weights are loaded further
# down in this script.
vision_encoder = timm.create_model(
    "vit_base_patch16_224",
    pretrained=False,
    num_classes=0,
    dynamic_img_size=True,
)

# Linear -> BatchNorm -> GELU -> Linear head mapping HIDDEN -> 2048 -> EMBED.
vision_pre_proj = nn.Sequential(
    nn.Linear(in_features=HIDDEN, out_features=2048),
    nn.BatchNorm1d(num_features=2048),
    nn.GELU(),
    nn.Linear(in_features=2048, out_features=EMBED),
)

# ── Text encoder ─────────────────────────────────────────────────────────────
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# The EOS token doubles as the padding token.
tokenizer.pad_token = tokenizer.eos_token
def tokenize_with_eos_readout(tokenizer, text, max_length=77):
    """Tokenize caption(s) into fixed-length, right-padded rows ending in EOS.

    Each caption is tokenized without special tokens, truncated to
    ``max_length - 1`` ids, and followed by exactly one EOS id. The row is
    then right-padded with ``tokenizer.pad_token_id`` up to ``max_length``.
    The attention mask is 1 for the caption ids and the EOS id, 0 for padding.

    Args:
        tokenizer: Callable invoked as ``tokenizer(caption,
            add_special_tokens=False, truncation=True, max_length=...)`` that
            returns a mapping with an ``"input_ids"`` list, and that exposes
            ``eos_token_id`` and ``pad_token_id``.
        text: A single caption string, or a sequence of caption strings.
        max_length: Length of every returned row, EOS included. Must be >= 1.

    Returns:
        dict with ``input_ids`` and ``attention_mask``, both int64 tensors of
        shape ``(num_captions, max_length)``; a single string yields one row.

    Raises:
        ValueError: If ``max_length`` is smaller than 1 (no room for EOS).
    """
    if max_length < 1:
        raise ValueError(f"max_length must be >= 1 to fit the EOS token, got {max_length}")

    captions = [text] if isinstance(text, str) else list(text)
    budget = max_length - 1  # one slot is always reserved for EOS

    id_rows = []
    mask_rows = []
    for caption in captions:
        if budget > 0:
            body = tokenizer(
                caption,
                add_special_tokens=False,
                truncation=True,
                max_length=budget,
            )["input_ids"]
            # Enforce the budget even if the tokenizer did not truncate.
            body = list(body)[:budget]
        else:
            # Only EOS fits; do not ask the tokenizer for zero tokens.
            body = []
        ids = body + [tokenizer.eos_token_id]
        pad_len = max_length - len(ids)
        id_rows.append(ids + [tokenizer.pad_token_id] * pad_len)
        mask_rows.append([1] * len(ids) + [0] * pad_len)

    # Explicit reshape keeps the (N, max_length) layout for an empty batch.
    shape = (len(captions), max_length)
    return dict(
        input_ids=torch.tensor(id_rows, dtype=torch.long).reshape(shape),
        attention_mask=torch.tensor(mask_rows, dtype=torch.long).reshape(shape),
    )
def last_unmasked_token(hidden, attention_mask):
    """Select, per sequence, the hidden state at the last unmasked position.

    The position is the highest index whose mask entry is non-zero, so the
    result is correct for right-padded and left-padded masks alike. A row
    whose mask is entirely zero falls back to position 0.

    Args:
        hidden: Tensor of shape ``(batch, seq_len, dim)``.
        attention_mask: Tensor of shape ``(batch, seq_len)``; non-zero marks
            a real (unmasked) token.

    Returns:
        Tensor of shape ``(batch, dim)``.
    """
    seq_len = attention_mask.size(1)
    # 1-based positions so that "no unmasked token" (max == 0) is
    # distinguishable from "only position 0 is unmasked" (max == 1).
    positions = torch.arange(1, seq_len + 1, device=attention_mask.device)
    last_pos = ((attention_mask != 0).long() * positions).max(dim=1).values
    last_idx = (last_pos.clamp(min=1) - 1).to(hidden.device)
    gather_idx = last_idx.view(-1, 1, 1).expand(-1, 1, hidden.size(-1))
    return hidden.gather(1, gather_idx).squeeze(1)
# GPT-2 architecture built from a config only; the checkpoint weights are
# loaded further down in this script. All dropout probabilities are zero.
text_config = GPT2Config(
    vocab_size=tokenizer.vocab_size,
    n_embd=HIDDEN,
    n_layer=12,
    n_head=12,
    n_inner=HIDDEN * 4,
    embd_pdrop=0.0,
    attn_pdrop=0.0,
    resid_pdrop=0.0,
)
text_encoder = GPT2Model(text_config)

# Linear -> BatchNorm -> GELU -> Linear head mapping HIDDEN -> 2048 -> EMBED.
text_pre_proj = nn.Sequential(
    nn.Linear(in_features=HIDDEN, out_features=2048),
    nn.BatchNorm1d(num_features=2048),
    nn.GELU(),
    nn.Linear(in_features=2048, out_features=EMBED),
)
# ── Load weights ─────────────────────────────────────────────────────────────
from huggingface_hub import hf_hub_download

REPO_ID = "lukaskuhndkfz/InfoNCE-ViT-B-DataComp-200k"


def _strip_prefix(state_dict, prefix):
    """Return the entries of *state_dict* under *prefix*, with the prefix removed."""
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


vision_weights = load_file(hf_hub_download(REPO_ID, "vision_encoder.safetensors"))
text_weights = load_file(hf_hub_download(REPO_ID, "text_encoder.safetensors"))

# Each file stores several sub-modules under key prefixes; only the
# "encoder." and "pre_proj." entries are used here.
vision_encoder.load_state_dict(_strip_prefix(vision_weights, "encoder."))
vision_pre_proj.load_state_dict(_strip_prefix(vision_weights, "pre_proj."))
text_encoder.load_state_dict(_strip_prefix(text_weights, "encoder."))
text_pre_proj.load_state_dict(_strip_prefix(text_weights, "pre_proj."))

# Switch every module to eval mode, the projection heads included. They
# contain BatchNorm1d, which in training mode normalises with batch statistics
# and raises on the single-sample batches used below.
vision_encoder.eval()
vision_pre_proj.eval()
text_encoder.eval()
text_pre_proj.eval()
# ── Encode an image ──────────────────────────────────────────────────────────
from torchvision import transforms
from PIL import Image

# Resize, centre-crop to 224, convert to a tensor, then normalise per channel.
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

image = Image.open("image.jpg").convert("RGB")
pixel_values = transform(image).unsqueeze(0)  # add a leading batch dimension

with torch.no_grad():
    image_features = vision_pre_proj(vision_encoder(pixel_values))  # (1, 512)
# ── Encode a caption ─────────────────────────────────────────────────────────
inputs = tokenize_with_eos_readout(tokenizer, "a photo of a cat")

with torch.no_grad():
    hidden = text_encoder(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    ).last_hidden_state
    # Read out the hidden state at the last unmasked position (the EOS token).
    text_hidden = last_unmasked_token(hidden, inputs["attention_mask"])
    text_features = text_pre_proj(text_hidden)  # (1, 512)
Files
| File | Contents |
|---|---|
| `vision_encoder.safetensors` | Vision encoder (`encoder.*`), pre-projection head (`pre_proj.*`), and cross-modal projector MLP (`projector.*`) |
| `text_encoder.safetensors` | Text encoder (`encoder.*`), pre-projection head (`pre_proj.*`), and cross-modal projector MLP (`projector.*`) |
| `config.json` | Architecture and training hyperparameters |
- Downloads last month
- 49
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support