# TurkishVLMTAMGAQA / model.py
import torch
import torch.nn as nn
from transformers import CLIPModel, AutoTokenizer, AutoModelForSeq2SeqLM
from huggingface_hub import hf_hub_download


class CLIP2MT5_CrossAttention(nn.Module):
    """Vision-language model that projects CLIP patch embeddings into the
    mT5 embedding space and prepends them to the text token sequence
    (prefix-style fusion); the mT5 decoder then cross-attends to the joint
    encoder output."""

    def __init__(self, clip_name='openai/clip-vit-base-patch32',
                 t5_name='mukayese/mt5-base-turkish-summarization'):
        super().__init__()
        self.clip = CLIPModel.from_pretrained(clip_name)
        self.tokenizer = AutoTokenizer.from_pretrained(t5_name)
        self.t5 = AutoModelForSeq2SeqLM.from_pretrained(t5_name)
        # Project CLIP vision hidden states into the mT5 model dimension.
        self.vis_proj = nn.Linear(
            self.clip.config.vision_config.hidden_size,
            self.t5.config.d_model
        )

    def forward(self, images, input_ids, attention_mask, labels=None):
        # CLIP patch embeddings: (batch, num_patches + 1, clip_hidden).
        vision_outputs = self.clip.vision_model(pixel_values=images).last_hidden_state
        vision_embeds = self.vis_proj(vision_outputs)
        # Token embeddings from the mT5 encoder's shared embedding table.
        text_embeds = self.t5.encoder.embed_tokens(input_ids)
        # Prepend the visual tokens to the text sequence and extend the
        # attention mask with ones for the (never-padded) visual positions.
        extended_input_embeds = torch.cat([vision_embeds, text_embeds], dim=1)
        extended_attention_mask = torch.cat([
            torch.ones(vision_embeds.size(0), vision_embeds.size(1),
                       dtype=attention_mask.dtype, device=attention_mask.device),
            attention_mask
        ], dim=1)
        if labels is not None:
            # Replace padding with -100 so it is ignored by the loss.
            labels = labels.clone()
            labels[labels == self.tokenizer.pad_token_id] = -100
        return self.t5(
            inputs_embeds=extended_input_embeds,
            attention_mask=extended_attention_mask,
            labels=labels,
            return_dict=True
        )

    @torch.no_grad()
    def generate(self, images, input_ids, attention_mask, **gen_kwargs):
        # Same prefix fusion as forward(), then delegate to mT5's generate().
        vision_outputs = self.clip.vision_model(pixel_values=images).last_hidden_state
        vision_embeds = self.vis_proj(vision_outputs)
        text_embeds = self.t5.encoder.embed_tokens(input_ids)
        extended_input_embeds = torch.cat([vision_embeds, text_embeds], dim=1)
        extended_attention_mask = torch.cat([
            torch.ones(vision_embeds.size(0), vision_embeds.size(1),
                       dtype=attention_mask.dtype, device=attention_mask.device),
            attention_mask
        ], dim=1)
        return self.t5.generate(
            inputs_embeds=extended_input_embeds,
            attention_mask=extended_attention_mask,
            **gen_kwargs
        )
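

# --- Illustrative helper (an assumption, not part of the original upload) ---
# A minimal sketch of how the (images, input_ids, attention_mask) triple for
# forward()/generate() could be built, assuming the stock CLIPImageProcessor
# matching the CLIP checkpoint above. The name `prepare_inputs` is hypothetical.
def prepare_inputs(model, image, question, device="cpu"):
    from transformers import CLIPImageProcessor  # lazy import, sketch only
    image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # `image` is a PIL.Image (or anything CLIPImageProcessor accepts).
    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
    enc = model.tokenizer(question, return_tensors="pt",
                          padding=True, truncation=True)
    return (pixel_values.to(device),
            enc.input_ids.to(device),
            enc.attention_mask.to(device))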


# Hugging Face Hub loader for a checkpoint saved as a plain state_dict.
def load_model(
    repo_id: str,
    filename: str = "model.pt",
    clip_name="openai/clip-vit-base-patch32",
    t5_name="mukayese/mt5-base-turkish-summarization",
    device=None
):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Download the checkpoint from the Hub (cached locally after the first call).
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    model = CLIP2MT5_CrossAttention(clip_name=clip_name, t5_name=t5_name)
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model
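

# --- Usage sketch (assumptions flagged inline) ---
# `repo_id` below is a placeholder; substitute the Hub repo that actually
# hosts `model.pt`. The image path, question, and generation settings are
# illustrative, and `prepare_inputs` is the hypothetical helper defined above.
if __name__ == "__main__":
    from PIL import Image

    model = load_model(repo_id="<user>/<TurkishVLMTAMGAQA-repo>")  # placeholder
    device = next(model.parameters()).device

    image = Image.open("example.jpg").convert("RGB")  # any local test image
    question = "Bu görselde ne görüyorsun?"  # "What do you see in this image?"
    images, input_ids, attention_mask = prepare_inputs(model, image, question, device)

    output_ids = model.generate(images, input_ids, attention_mask,
                                max_new_tokens=64, num_beams=4)
    print(model.tokenizer.decode(output_ids[0], skip_special_tokens=True))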