import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gradio as gr
from torchvision import transforms
from PIL import Image
from transformers import CLIPTokenizer
from huggingface_hub import hf_hub_download
# ============================================================================
# CONFIGURATION - HARDCODED FOR I3-CLIP ARCHITECTURE
# ============================================================================
D_MODEL = 768
N_RWKV = 12
N_ATTN = 4
N_HEADS = 12
FFN_MULT = 4
MAX_LEN = 77
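# These values are fixed to match the i3-CLIP checkpoint:
#   D_MODEL  - shared embedding width of the vision and text towers
#   N_RWKV   - RWKV blocks in the text encoder
#   N_ATTN   - transformer encoder layers stacked after the RWKV blocks
#   N_HEADS  - attention heads per transformer layer
#   FFN_MULT - hidden-size multiplier for the RWKV channel-mix FFN
#   MAX_LEN  - CLIP tokenizer context length
# Only N_HEADS and MAX_LEN are referenced directly below; the class defaults
# mirror the remaining values.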
# ============================================================================
# 1. RWKV CORE (JIT OPTIMIZED)
# ============================================================================
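# rwkv_linear_attention implements the RWKV-v4-style WKV recurrence: a
# per-channel, exponentially decaying weighted average of past values v,
# weighted by exp(k), with a learned bonus u applied to the current token.
# (state_aa, state_bb) carry the running numerator / denominator and state_pp
# a running exponent offset; subtracting the elementwise maximum p before each
# exp() keeps the recurrence numerically stable.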
@torch.jit.script
def rwkv_linear_attention(B: int, T: int, C: int,
r: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
w: torch.Tensor, u: torch.Tensor,
state_init: torch.Tensor):
y = torch.zeros_like(v)
state_aa = torch.zeros(B, C, dtype=torch.float32, device=r.device)
state_bb = torch.zeros(B, C, dtype=torch.float32, device=r.device)
state_pp = state_init.clone()
for t in range(T):
rt, kt, vt = r[:, t], k[:, t], v[:, t]
ww = u + state_pp
p = torch.maximum(ww, kt)
e1 = torch.exp(ww - p)
e2 = torch.exp(kt - p)
wkv = (state_aa * e1 + vt * e2) / (state_bb * e1 + e2 + 1e-6)
y[:, t] = wkv
ww = w + state_pp
p = torch.maximum(ww, kt)
e1 = torch.exp(ww - p)
e2 = torch.exp(kt - p)
state_aa = state_aa * e1 + vt * e2
state_bb = state_bb * e1 + e2
state_pp = p
return y
class RWKVTimeMix(nn.Module):
def __init__(self, d_model):
super().__init__()
self.d_model = d_model
self.time_decay = nn.Parameter(torch.ones(d_model))
self.time_first = nn.Parameter(torch.ones(d_model))
self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
self.time_mix_v = nn.Parameter(torch.ones(1, 1, d_model))
self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
self.key = nn.Linear(d_model, d_model, bias=False)
self.value = nn.Linear(d_model, d_model, bias=False)
self.receptance = nn.Linear(d_model, d_model, bias=False)
self.output = nn.Linear(d_model, d_model, bias=False)
def forward(self, x):
B, T, C = x.size()
xx = torch.cat([torch.zeros((B, 1, C), device=x.device), x[:, :-1]], dim=1)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k, v = self.key(xk), self.value(xv)
r = torch.sigmoid(self.receptance(xr))
w, u = -torch.exp(self.time_decay), self.time_first
state_init = torch.full((B, C), -1e30, dtype=torch.float32, device=x.device)
rwkv = rwkv_linear_attention(B, T, C, r, k, v, w, u, state_init)
return self.output(r * rwkv)
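# Both RWKVTimeMix (above) and RWKVChannelMix (below) use RWKV's "token shift":
# xx is the input shifted right by one time step (zero-padded at t=0), and the
# learned time_mix_* parameters interpolate between the current and previous
# token before the key / value / receptance projections.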
class RWKVChannelMix(nn.Module):
def __init__(self, d_model, ffn_mult=4):
super().__init__()
self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
hidden_sz = d_model * ffn_mult
self.key = nn.Linear(d_model, hidden_sz, bias=False)
self.receptance = nn.Linear(d_model, d_model, bias=False)
self.value = nn.Linear(hidden_sz, d_model, bias=False)
def forward(self, x):
B, T, C = x.size()
xx = torch.cat([torch.zeros((B, 1, C), device=x.device), x[:, :-1]], dim=1)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = torch.square(torch.relu(self.key(xk)))
return torch.sigmoid(self.receptance(xr)) * self.value(k)
class RWKVBlock(nn.Module):
def __init__(self, d_model):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.att = RWKVTimeMix(d_model)
self.ln2 = nn.LayerNorm(d_model)
self.ffn = RWKVChannelMix(d_model)
def forward(self, x):
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
# ============================================================================
# 2. VISION ENCODER
# ============================================================================
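# ResNet-50-style vision tower: a 7x7 stem, four stages of Bottleneck blocks
# (3, 4, 6, 3 blocks, expansion 4), global average pooling, and a linear
# projection from 2048 channels down to D_MODEL.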
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
self.downsample = nn.Identity()
if stride != 1 or in_planes != self.expansion * planes:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.downsample(x)
out = F.relu(out)
return out
class VisionEncoderLarge(nn.Module):
def __init__(self, d_model=768):
super().__init__()
self.stem = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(64), nn.ReLU(inplace=True),
nn.MaxPool2d(3, 2, 1)
)
self.layer1 = self._make_layer(64, 64, 3)
self.layer2 = self._make_layer(256, 128, 4, stride=2)
self.layer3 = self._make_layer(512, 256, 6, stride=2)
self.layer4 = self._make_layer(1024, 512, 3, stride=2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(2048, d_model)
def _make_layer(self, in_planes, planes, blocks, stride=1):
layers = [Bottleneck(in_planes, planes, stride)]
for _ in range(1, blocks):
layers.append(Bottleneck(planes * 4, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return self.fc(self.avgpool(x).flatten(1))
# ============================================================================
# 3. TEXT ENCODER
# ============================================================================
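# Hybrid text tower: CLIP BPE token embeddings plus learned positional
# embeddings are passed through N_RWKV RWKV blocks, then N_ATTN standard
# TransformerEncoderLayers, and the representation at the final sequence
# position is taken as the text embedding.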
class HybridTextEncoderLarge(nn.Module):
def __init__(self, vocab_size, d_model=768, n_rwkv=12, n_attn=4, max_len=77):
super().__init__()
self.token_embed = nn.Embedding(vocab_size, d_model)
self.pos_embed = nn.Parameter(torch.zeros(1, max_len, d_model))
self.rwkv_layers = nn.ModuleList([RWKVBlock(d_model) for _ in range(n_rwkv)])
self.attn_layers = nn.ModuleList([
nn.TransformerEncoderLayer(d_model=d_model, nhead=N_HEADS,
dim_feedforward=d_model*4,
batch_first=True, activation="gelu")
for _ in range(n_attn)
])
self.ln_final = nn.LayerNorm(d_model)
def forward(self, x):
x = self.token_embed(x) + self.pos_embed[:, :x.size(1), :]
for layer in self.rwkv_layers: x = layer(x)
for layer in self.attn_layers: x = layer(x)
return self.ln_final(x[:, -1, :])
# ============================================================================
# 4. WRAPPER
# ============================================================================
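# CLIP-style contrastive wrapper: both towers project into the same D_MODEL
# space, features are L2-normalized, and logit_scale (initialized to
# log(1/0.07), as in CLIP) acts as a learned temperature on the
# cosine-similarity logits.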
class i3CLIPHybridLarge(nn.Module):
def __init__(self, vocab_size, d_model=768):
super().__init__()
self.visual = VisionEncoderLarge(d_model=d_model)
self.textual = HybridTextEncoderLarge(vocab_size=vocab_size, d_model=d_model)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
def forward(self, images, texts):
img_f = F.normalize(self.visual(images), dim=-1)
txt_f = F.normalize(self.textual(texts), dim=-1)
scale = self.logit_scale.exp()
logits = scale * img_f @ txt_f.t()
return logits, logits.t()
# ============================================================================
# 5. INFERENCE LOGIC
# ============================================================================
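# The OpenAI CLIP tokenizer supplies the BPE vocabulary; the text encoder's
# embedding table is sized from tokenizer.vocab_size, so it must match the
# vocabulary used when the checkpoint was trained.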
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
model = i3CLIPHybridLarge(tokenizer.vocab_size).to(device)
# Download the checkpoint from the i3-lab/i3-CLIP repo on the Hugging Face Hub
# and load it; strict=False tolerates keys that differ between the checkpoint
# and this model definition.
print("Downloading and loading model weights...")
checkpoint_path = hf_hub_download(repo_id="i3-lab/i3-CLIP", filename="pytorch_model.bin")
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict, strict=False)
model.eval()
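# Resize to 224x224 and normalize with values close to CLIP's published image
# normalization statistics.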
preprocess = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize((0.48, 0.45, 0.40), (0.26, 0.26, 0.27))
])
def predict(image, labels_text):
if image is None: return None
labels = [l.strip() for l in labels_text.split(",")]
# Process image
img_tensor = preprocess(image).unsqueeze(0).to(device)
# Process text
txt_tokens = tokenizer(
labels, padding='max_length', truncation=True, max_length=MAX_LEN, return_tensors="pt"
).input_ids.to(device)
with torch.no_grad():
img_features = F.normalize(model.visual(img_tensor), dim=-1)
txt_features = F.normalize(model.textual(txt_tokens), dim=-1)
logits = (img_features @ txt_features.t()) * model.logit_scale.exp()
probs = F.softmax(logits, dim=-1).cpu().numpy()[0]
return {labels[i]: float(probs[i]) for i in range(len(labels))}
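# Example usage (the image path is illustrative):
#   probs = predict(Image.open("cat.jpg"), "a photo of a cat, a photo of a dog")
#   # probs maps each candidate label to its softmax probability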
# ============================================================================
# 6. GRADIO INTERFACE
# ============================================================================
demo = gr.Interface(
fn=predict,
inputs=[
gr.Image(type="pil", label="Input Image"),
gr.Textbox(label="Candidate Labels (comma separated)", value="a photo of a cat, a photo of a dog, a landscape")
],
outputs=gr.Label(num_top_classes=5),
title="i3-CLIP Hybrid RWKV-Transformer Large",
description="This Space runs the i3-CLIP architecture: a ResNet-style Bottleneck vision encoder paired with a hybrid RWKV-Transformer text encoder. Weights are loaded from the i3-lab/i3-CLIP checkpoint."
)
if __name__ == "__main__":
demo.launch()