import os

import clip
import pytorch_lightning as pl
import streamlit as st
import torch
import torch.nn as nn
import wandb
from transformers import GPT2Tokenizer, GPT2LMHeadModel
class ImageEncoder(nn.Module):
    """Frozen CLIP image encoder that returns L2-normalized image features."""

    def __init__(self, base_network):
        super(ImageEncoder, self).__init__()
        self.base_network = base_network
        # Embedding width read from CLIP's token embedding table.
        self.embedding_size = self.base_network.token_embedding.weight.shape[1]

    def forward(self, images):
        # CLIP stays frozen: encode the images and L2-normalize the features.
        with torch.no_grad():
            x = self.base_network.encode_image(images)
            x = x / x.norm(dim=1, keepdim=True)
        x = x.float()
        return x
class Mapping(nn.Module):
    """Project a CLIP image embedding into a fixed-length prefix of GPT-2 embeddings."""

    def __init__(self, clip_embedding_size, gpt_embedding_size, length=30):  # length: prefix length in tokens
        super(Mapping, self).__init__()
        self.clip_embedding_size = clip_embedding_size
        self.gpt_embedding_size = gpt_embedding_size
        self.length = length
        self.fc1 = nn.Linear(clip_embedding_size, gpt_embedding_size * length)

    def forward(self, x):
        x = self.fc1(x)
        # (batch, gpt_embedding_size * length) -> (batch, length, gpt_embedding_size)
        return x.view(-1, self.length, self.gpt_embedding_size)
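
# Shape sketch (not called anywhere in the app): illustrates how Mapping turns one
# CLIP image embedding into a fixed-length prefix of GPT-2-sized embeddings.
# The 768-dim sizes below are assumptions matching CLIP ViT-L/14 and GPT-2 small.
def _mapping_shape_sketch():
    mapper = Mapping(clip_embedding_size=768, gpt_embedding_size=768, length=30)
    clip_features = torch.randn(4, 768)   # (batch, clip_embedding_size)
    prefix = mapper(clip_features)        # (batch, length, gpt_embedding_size)
    assert prefix.shape == (4, 30, 768)
    return prefix
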
class TextDecoder(nn.Module):
    """Wrapper around GPT-2 that decodes from pre-computed input embeddings."""

    def __init__(self, base_network):
        super(TextDecoder, self).__init__()
        self.base_network = base_network
        self.embedding_size = self.base_network.transformer.wte.weight.shape[1]
        self.vocab_size = self.base_network.transformer.wte.weight.shape[0]

    def forward(self, concat_embedding, mask=None):
        return self.base_network(inputs_embeds=concat_embedding, attention_mask=mask)

    def get_embedding(self, texts):
        # Look up GPT-2 token embeddings for the given token ids.
        return self.base_network.transformer.wte(texts)
class ImageCaptioner(pl.LightningModule):
    def __init__(self, clip_model, gpt_model, tokenizer, total_steps, max_length=20):
        super(ImageCaptioner, self).__init__()
        self.padding_token_id = tokenizer.pad_token_id
        # self.stop_token_id = tokenizer.encode('.')[0]

        # Define networks
        self.clip = ImageEncoder(clip_model)
        self.gpt = TextDecoder(gpt_model)
        self.mapping_network = Mapping(self.clip.embedding_size, self.gpt.embedding_size, max_length)

        # Define variables
        self.total_steps = total_steps
        self.max_length = max_length
        self.clip_embedding_size = self.clip.embedding_size
        self.gpt_embedding_size = self.gpt.embedding_size
        self.gpt_vocab_size = self.gpt.vocab_size

    def forward(self, images, texts, masks):
        # Token embeddings of the caption tokens.
        texts_embedding = self.gpt.get_embedding(texts)
        # Frozen CLIP features projected into a prefix of GPT-2 embeddings.
        images_embedding = self.clip(images)
        images_projection = self.mapping_network(images_embedding).view(-1, self.max_length, self.gpt_embedding_size)
        # Prepend the image prefix to the caption embeddings and decode with GPT-2.
        embedding_concat = torch.cat((images_projection, texts_embedding), dim=1)
        out = self.gpt(embedding_concat, masks)
        return out
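
# Training-loss sketch (not part of the deployed app, and not necessarily how the
# checkpoint was trained): one plausible way to compute a captioning loss is to run
# forward(), keep only the logits that predict caption tokens (i.e. drop the image
# prefix positions), and apply cross-entropy against the caption ids, ignoring padding.
def _caption_loss_sketch(model, images, texts, masks):
    out = model(images, texts, masks)
    # Logits at position i predict the token at position i + 1, so the slice below
    # aligns the last prefix position with the first caption token.
    logits = out.logits[:, model.max_length - 1:-1, :]
    loss = nn.functional.cross_entropy(
        logits.reshape(-1, model.gpt_vocab_size),
        texts.reshape(-1),
        ignore_index=model.padding_token_id,
    )
    return loss
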
# @st.cache_resource
# def download_trained_model():
#     wandb.init(anonymous="must")
#     api = wandb.Api()
#     artifact = api.artifact('hungchiehwu/CLIP-L14_GPT/model-ql03493w:v3')
#     artifact_dir = artifact.download()
#     wandb.finish()
#     return artifact_dir
@st.cache_resource
def load_clip_model():
    # Load the pretrained CLIP ViT-L/14 model and its image preprocessing transform.
    clip_model, image_transform = clip.load("ViT-L/14", device="cpu")
    return clip_model, image_transform
@st.cache_resource
def load_gpt_model():
    # Load pretrained GPT-2 and reuse the EOS token as the padding token.
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    gpt_model = GPT2LMHeadModel.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token
    return gpt_model, tokenizer
@st.cache_resource
def load_model():
    # Path to the fine-tuned checkpoint downloaded from wandb (see download_trained_model above).
    artifact_dir = "./artifacts/model-ql03493w:v3"
    PATH = os.path.join(artifact_dir, "model.ckpt")

    # Load the pretrained GPT-2 and CLIP models.
    clip_model, image_transform = load_clip_model()
    gpt_model, tokenizer = load_gpt_model()

    # Load the fine-tuned weights into the captioning model.
    print(PATH)
    print(os.getcwd())
    model = ImageCaptioner(clip_model, gpt_model, tokenizer, 0)
    checkpoint = torch.load(PATH, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint["state_dict"])

    return model, image_transform, tokenizer
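
# Usage sketch (hypothetical, not called by this file): greedy caption generation
# with the objects returned by load_model(). Assumes `pil_image` is a PIL.Image.
def _generate_caption_sketch(pil_image, max_new_tokens=20):
    model, image_transform, tokenizer = load_model()
    model.eval()

    image = image_transform(pil_image).unsqueeze(0)            # (1, 3, H, W)
    with torch.no_grad():
        prefix = model.mapping_network(model.clip(image))      # (1, max_length, gpt_emb)
        embeddings = prefix
        generated_ids = []
        for _ in range(max_new_tokens):
            logits = model.gpt(embeddings).logits              # (1, seq_len, vocab)
            next_id = logits[:, -1, :].argmax(dim=-1)          # greedy pick
            if next_id.item() == tokenizer.eos_token_id:
                break
            generated_ids.append(next_id.item())
            next_embedding = model.gpt.get_embedding(next_id).unsqueeze(1)
            embeddings = torch.cat((embeddings, next_embedding), dim=1)

    return tokenizer.decode(generated_ids)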