import os

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from transformers import CLIPTokenizer

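# Global hyperparameters. N_HEADS and MAX_LEN are referenced directly below;
# the remaining values mirror the defaults baked into the model classes.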
D_MODEL = 768
N_RWKV = 12
N_ATTN = 4
N_HEADS = 12
FFN_MULT = 4
MAX_LEN = 77

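# Numerically stabilised WKV recurrence used by RWKVTimeMix. `state_pp` keeps a
# running maximum in the log domain so the exponentials stay bounded. Note that
# `r` (receptance) is accepted but unused here; the caller applies it as a
# sigmoid gate on the returned output.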
@torch.jit.script
def rwkv_linear_attention(B: int, T: int, C: int,
                          r: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                          w: torch.Tensor, u: torch.Tensor,
                          state_init: torch.Tensor):
    y = torch.zeros_like(v)
    state_aa = torch.zeros(B, C, dtype=torch.float32, device=r.device)
    state_bb = torch.zeros(B, C, dtype=torch.float32, device=r.device)
    state_pp = state_init.clone()

    for t in range(T):
        rt, kt, vt = r[:, t], k[:, t], v[:, t]
        ww = u + state_pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        wkv = (state_aa * e1 + vt * e2) / (state_bb * e1 + e2 + 1e-6)
        y[:, t] = wkv

        ww = w + state_pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        state_aa = state_aa * e1 + vt * e2
        state_bb = state_bb * e1 + e2
        state_pp = p

    return y

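# RWKV time-mix block: keys, values, and receptance are computed from a learned
# interpolation between each token and its predecessor (token shift), fed
# through the WKV recurrence above, and gated by the sigmoid receptance.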
class RWKVTimeMix(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.time_decay = nn.Parameter(torch.ones(d_model))
        self.time_first = nn.Parameter(torch.ones(d_model))
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_v = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.output = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        B, T, C = x.size()
        xx = torch.cat([torch.zeros((B, 1, C), device=x.device), x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k, v = self.key(xk), self.value(xv)
        r = torch.sigmoid(self.receptance(xr))
        w, u = -torch.exp(self.time_decay), self.time_first
        state_init = torch.full((B, C), -1e30, dtype=torch.float32, device=x.device)
        rwkv = rwkv_linear_attention(B, T, C, r, k, v, w, u, state_init)
        return self.output(r * rwkv)

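# RWKV channel-mix block: the position-wise feed-forward counterpart, using the
# same token-shift interpolation, a squared-ReLU activation, and a sigmoid
# receptance gate.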
class RWKVChannelMix(nn.Module):
    def __init__(self, d_model, ffn_mult=4):
        super().__init__()
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        hidden_sz = d_model * ffn_mult
        self.key = nn.Linear(d_model, hidden_sz, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(hidden_sz, d_model, bias=False)

    def forward(self, x):
        B, T, C = x.size()
        xx = torch.cat([torch.zeros((B, 1, C), device=x.device), x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = torch.square(torch.relu(self.key(xk)))
        return torch.sigmoid(self.receptance(xr)) * self.value(k)

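# One RWKV layer: pre-LayerNorm time-mix and channel-mix sub-blocks with
# residual connections.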
class RWKVBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.att = RWKVTimeMix(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = RWKVChannelMix(d_model)

    def forward(self, x):
        x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

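# Vision encoder building block: a ResNet-style bottleneck (1x1 -> 3x3 -> 1x1
# convolutions) with a projection shortcut whenever the spatial size or channel
# count changes.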
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.downsample = nn.Identity()
        if stride != 1 or in_planes != self.expansion * planes:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.downsample(x)
        return F.relu(out)

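# ResNet-50-like backbone (3/4/6/3 bottleneck blocks) followed by global average
# pooling and a linear projection to the shared embedding width.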
class VisionEncoderLarge(nn.Module):
    def __init__(self, d_model=768):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1),
        )
        self.layer1 = self._make_layer(64, 64, 3)
        self.layer2 = self._make_layer(256, 128, 4, stride=2)
        self.layer3 = self._make_layer(512, 256, 6, stride=2)
        self.layer4 = self._make_layer(1024, 512, 3, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(2048, d_model)

    def _make_layer(self, in_planes, planes, blocks, stride=1):
        layers = [Bottleneck(in_planes, planes, stride)]
        for _ in range(1, blocks):
            layers.append(Bottleneck(planes * 4, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return self.fc(self.avgpool(x).flatten(1))

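# Text encoder: token + learned positional embeddings, a stack of RWKV blocks,
# then a few bidirectional transformer encoder layers. The LayerNormed hidden
# state at the last position is used as the sequence embedding.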
class HybridTextEncoderLarge(nn.Module):
    def __init__(self, vocab_size, d_model=768, n_rwkv=12, n_attn=4, max_len=77):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Parameter(torch.zeros(1, max_len, d_model))
        self.rwkv_layers = nn.ModuleList([RWKVBlock(d_model) for _ in range(n_rwkv)])
        self.attn_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=d_model, nhead=N_HEADS,
                                       dim_feedforward=d_model * 4,
                                       batch_first=True, activation="gelu")
            for _ in range(n_attn)
        ])
        self.ln_final = nn.LayerNorm(d_model)

    def forward(self, x):
        x = self.token_embed(x) + self.pos_embed[:, :x.size(1), :]
        for layer in self.rwkv_layers:
            x = layer(x)
        for layer in self.attn_layers:
            x = layer(x)
        return self.ln_final(x[:, -1, :])

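# Full dual-encoder model: normalised image and text embeddings compared with a
# learned temperature (logit_scale), CLIP-style.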
class i3CLIPHybridLarge(nn.Module):
    def __init__(self, vocab_size, d_model=768):
        super().__init__()
        self.visual = VisionEncoderLarge(d_model=d_model)
        self.textual = HybridTextEncoderLarge(vocab_size=vocab_size, d_model=d_model)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, images, texts):
        img_f = F.normalize(self.visual(images), dim=-1)
        txt_f = F.normalize(self.textual(texts), dim=-1)
        scale = self.logit_scale.exp()
        logits = scale * img_f @ txt_f.t()
        return logits, logits.t()

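# Runtime setup: pick a device, load the CLIP tokenizer (for its vocabulary and
# text preprocessing), and build the model.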
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
model = i3CLIPHybridLarge(tokenizer.vocab_size).to(device)

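# Download the checkpoint from the Hugging Face Hub and load it; strict=False
# tolerates minor key mismatches between the checkpoint and this definition.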
print("Downloading and loading model weights...")
checkpoint_path = hf_hub_download(repo_id="i3-lab/i3-CLIP", filename="pytorch_model.bin")
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict, strict=False)
model.eval()

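# Image preprocessing: resize to 224x224, convert to a tensor, and normalise
# with fixed per-channel statistics.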
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.48, 0.45, 0.40), (0.26, 0.26, 0.27)),
])

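# Zero-shot classification: embed the image and each candidate label, take the
# scaled cosine similarities, and return a softmax over the labels.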
def predict(image, labels_text):
    if image is None:
        return None

    # Drop empty entries so trailing commas or blank input don't produce bogus labels.
    labels = [l.strip() for l in labels_text.split(",") if l.strip()]
    if not labels:
        return None

    img_tensor = preprocess(image).unsqueeze(0).to(device)

    txt_tokens = tokenizer(
        labels, padding="max_length", truncation=True, max_length=MAX_LEN, return_tensors="pt"
    ).input_ids.to(device)

    with torch.no_grad():
        img_features = F.normalize(model.visual(img_tensor), dim=-1)
        txt_features = F.normalize(model.textual(txt_tokens), dim=-1)
        logits = (img_features @ txt_features.t()) * model.logit_scale.exp()
        probs = F.softmax(logits, dim=-1).cpu().numpy()[0]

    return {label: float(prob) for label, prob in zip(labels, probs)}

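# Gradio UI: an image input plus a comma-separated list of candidate labels,
# with the top predictions shown in a label widget.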
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(label="Candidate Labels (comma separated)",
                   value="a photo of a cat, a photo of a dog, a landscape"),
    ],
    outputs=gr.Label(num_top_classes=5),
    title="i3-CLIP Hybrid RWKV-Transformer Large",
    description=(
        "This Space uses the i3-CLIP architecture: a ResNet-like bottleneck vision encoder "
        "and a hybrid RWKV-Transformer text encoder. Weights are downloaded from the "
        "i3-lab/i3-CLIP repository on the Hugging Face Hub."
    ),
)

if __name__ == "__main__":
    demo.launch()