import os

import clip
import pytorch_lightning as pl
import streamlit as st
import torch
import torch.nn as nn
import wandb
from transformers import GPT2Tokenizer, GPT2LMHeadModel


class ImageEncoder(nn.Module):
    def __init__(self, base_network):
        super(ImageEncoder, self).__init__()
        self.base_network = base_network
        self.embedding_size = self.base_network.token_embedding.weight.shape[1]

    def forward(self, images):
        with torch.no_grad():
            x = self.base_network.encode_image(images)
            x = x / x.norm(dim=1, keepdim=True)
            x = x.float()
        return x


class Mapping(nn.Module):
    # Map the CLIP image feature vector into the GPT-2 embedding space
    def __init__(self, clip_embedding_size, gpt_embedding_size, length=30):
        # length: number of prefix embeddings fed to GPT-2
        super(Mapping, self).__init__()
        self.clip_embedding_size = clip_embedding_size
        self.gpt_embedding_size = gpt_embedding_size
        self.length = length
        self.fc1 = nn.Linear(clip_embedding_size, gpt_embedding_size * length)

    def forward(self, x):
        x = self.fc1(x)
        return x.view(-1, self.length, self.gpt_embedding_size)


class TextDecoder(nn.Module):
    def __init__(self, base_network):
        super(TextDecoder, self).__init__()
        self.base_network = base_network
        self.embedding_size = self.base_network.transformer.wte.weight.shape[1]
        self.vocab_size = self.base_network.transformer.wte.weight.shape[0]

    def forward(self, concat_embedding, mask=None):
        return self.base_network(inputs_embeds=concat_embedding, attention_mask=mask)

    def get_embedding(self, texts):
        return self.base_network.transformer.wte(texts)


class ImageCaptioner(pl.LightningModule):
    def __init__(self, clip_model, gpt_model, tokenizer, total_steps, max_length=20):
        super(ImageCaptioner, self).__init__()
        self.padding_token_id = tokenizer.pad_token_id
        # self.stop_token_id = tokenizer.encode('.')[0]

        # Define networks
        self.clip = ImageEncoder(clip_model)
        self.gpt = TextDecoder(gpt_model)
        self.mapping_network = Mapping(self.clip.embedding_size, self.gpt.embedding_size, max_length)

        # Define variables
        self.total_steps = total_steps
        self.max_length = max_length
        self.clip_embedding_size = self.clip.embedding_size
        self.gpt_embedding_size = self.gpt.embedding_size
        self.gpt_vocab_size = self.gpt.vocab_size

    def forward(self, images, texts, masks):
        texts_embedding = self.gpt.get_embedding(texts)
        images_embedding = self.clip(images)

        # Project the image feature to a prefix of GPT-2 embeddings and prepend it to the caption embeddings
        images_projection = self.mapping_network(images_embedding).view(-1, self.max_length, self.gpt_embedding_size)
        embedding_concat = torch.cat((images_projection, texts_embedding), dim=1)
        out = self.gpt(embedding_concat, masks)
        return out


# @st.cache_resource
# def download_trained_model():
#     wandb.init(anonymous="must")
#     api = wandb.Api()
#     artifact = api.artifact('hungchiehwu/CLIP-L14_GPT/model-ql03493w:v3')
#     artifact_dir = artifact.download()
#     wandb.finish()
#     return artifact_dir


@st.cache_resource
def load_clip_model():
    clip_model, image_transform = clip.load("ViT-L/14", device="cpu")
    return clip_model, image_transform


@st.cache_resource
def load_gpt_model():
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    gpt_model = GPT2LMHeadModel.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token
    return gpt_model, tokenizer


@st.cache_resource
def load_model():
    # Path to the fine-tuned checkpoint downloaded from the wandb artifact
    artifact_dir = "./artifacts/model-ql03493w:v3"
    PATH = f"{artifact_dir[2:]}/model.ckpt"

    # Load pretrained CLIP and GPT-2 models from OpenAI / Hugging Face
    clip_model, image_transform = load_clip_model()
    gpt_model, tokenizer = load_gpt_model()

    # Load the fine-tuned weights
    print(PATH)
    print(os.getcwd())
    model = ImageCaptioner(clip_model, gpt_model, tokenizer, 0)
    checkpoint = torch.load(PATH, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint["state_dict"])
    return model, image_transform, tokenizer
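

# --- Inference sketch (illustrative, not part of the original file) ---
# The classes above only define the training-time forward pass, so the helper below is a
# minimal sketch of how a caption could be decoded from the loaded model: encode the image
# with CLIP, project it to `max_length` GPT-2 prefix embeddings with the mapping network,
# then greedily append one token at a time. The name `generate_caption`, the greedy
# decoding, and the EOS stopping rule are assumptions; the original project may decode
# differently (e.g. beam search or a '.' stop token).
@torch.no_grad()
def generate_caption(model, image_transform, tokenizer, pil_image, max_tokens=20):
    image = image_transform(pil_image).unsqueeze(0)       # (1, 3, H, W), CLIP-preprocessed
    prefix = model.clip(image)                            # (1, clip_embedding_size)
    embeddings = model.mapping_network(prefix)            # (1, max_length, gpt_embedding_size)

    generated_ids = []
    for _ in range(max_tokens):
        logits = model.gpt(embeddings).logits             # (1, seq_len, vocab_size)
        next_id = logits[0, -1].argmax().item()           # greedy: most likely next token
        if next_id == tokenizer.eos_token_id:
            break
        generated_ids.append(next_id)
        # Embed the new token and extend the sequence for the next step
        next_embedding = model.gpt.get_embedding(torch.tensor([[next_id]]))
        embeddings = torch.cat((embeddings, next_embedding), dim=1)

    return tokenizer.decode(generated_ids)


# Example use on the Streamlit side (assumed wiring):
#   model, image_transform, tokenizer = load_model()
#   caption = generate_caption(model, image_transform, tokenizer, Image.open(uploaded_file))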