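# Gradio demo app for a multimodal "chat with an image" model: CLIP ViT-B/32
# supplies image features, a trained projection + residual block maps them into
# phi-2's embedding space, and a LoRA-finetuned phi-2 generates the answer.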
import gradio as gr
import torch
import torch.nn as nn
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor, CLIPVisionModel
clip_model_name = "openai/clip-vit-base-patch32"
phi_model_name = "microsoft/phi-2"

tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(clip_model_name)
tokenizer.pad_token = tokenizer.eos_token  # phi-2's tokenizer has no pad token; reuse EOS

IMAGE_TOKEN_ID = 23893  # phi-2 token id for the word "comment", used to mark the image position in the prompt
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_embed = 768   # CLIP ViT-B/32 hidden size
phi_embed = 2560   # phi-2 hidden size
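# Residual MLP block applied to the projected image embeddings before they are
# concatenated with the phi-2 text embeddings.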
class SimpleResBlock(nn.Module):
    def __init__(self, phi_embed):
        super().__init__()
        self.pre_norm = nn.LayerNorm(phi_embed)
        self.proj = nn.Sequential(
            nn.Linear(phi_embed, phi_embed),
            nn.GELU(),
            nn.Linear(phi_embed, phi_embed)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)
# models
clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
resblock = SimpleResBlock(phi_embed).to(device)
phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name, trust_remote_code=True).to(device)
# load trained weights: merge the LoRA adapter into phi-2 and restore the projection/resblock checkpoints
model_to_merge = PeftModel.from_pretrained(phi_model, './model_chkpt/lora_adaptor')
merged_model = model_to_merge.merge_and_unload()
projection.load_state_dict(torch.load('./model_chkpt/step2_projection.pth', map_location=torch.device(device)))
resblock.load_state_dict(torch.load('./model_chkpt/step2_resblock.pth', map_location=torch.device(device)))
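# Greedy decoding: the projected image embeddings, an image-marker token, and the
# question embeddings are concatenated and fed to the merged phi-2 model; at each
# step the argmax token is appended until max_generate_length tokens are produced.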
@torch.no_grad()  # inference only; no gradients needed
def model_generate_ans(img, val_q):
    max_generate_length = 100
    # encode the image with CLIP (dropping the CLS token) and project into phi-2's embedding space
    image_processed = processor(images=img, return_tensors="pt").to(device)
    clip_val_outputs = clip_model(**image_processed).last_hidden_state[:, 1:, :]
    val_image_embeds = projection(clip_val_outputs)
    val_image_embeds = resblock(val_image_embeds).to(merged_model.dtype)  # match the language model's dtype
    # embed the image-marker token and the question text
    img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
    img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
    val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
    val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
    # prompt = [image embeddings | image-marker token | question tokens]
    val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1)  # (1, seq_len, 2560)
    predicted_caption = torch.full((1, max_generate_length), 50256)  # pre-filled with the EOS token id
    for g in range(max_generate_length):
        phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits']  # (1, seq_len, vocab_size)
        predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)  # (1, 1, vocab_size)
        predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)  # (1, 1)
        predicted_caption[:, g] = predicted_word_token.view(1, -1)
        next_token_embeds = merged_model.model.embed_tokens(predicted_word_token)  # (1, 1, 2560)
        val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)
    # skip_special_tokens drops the EOS padding from the decoded answer
    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption, skip_special_tokens=True)
    return predicted_captions_decoded[0]
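# Example call outside the UI (hypothetical image path, for illustration only):
#   from PIL import Image
#   print(model_generate_ans(Image.open("sample.jpg"), "What is shown in the image?"))

# Gradio UI: an image and a question go in, the generated answer comes out.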
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Chat with MultiModal GPT!
        Built by combining the CLIP vision model with the phi-2 language model.
        """
    )
    # app GUI
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label='Image')
            img_question = gr.Text(label='Question')
        with gr.Column():
            img_answer = gr.Text(label='Answer')
    section_btn = gr.Button("Submit")
    section_btn.click(model_generate_ans, inputs=[img_input, img_question], outputs=[img_answer])
if __name__ == "__main__":
    demo.launch()