multimodalgpt / app.py
sanjanatule's picture
Update app.py
8625fcd verified
raw
history blame
No virus
3.13 kB
import gradio as gr
import peft
from peft import LoraConfig
from transformers import AutoTokenizer,BitsAndBytesConfig, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor
import torch
from peft import PeftModel
# Hugging Face hub identifiers of the two pretrained backbones.
clip_model_name = "openai/clip-vit-base-patch32"
phi_model_name = "microsoft/phi-2"

# Input / output feature sizes of the linear layer that maps CLIP vision
# features into the language model's embedding space.
clip_embed = 768
phi_embed = 2560

# Token whose embedding is placed between the image embeddings and the
# question embeddings (original note: "token for word comment").
IMAGE_TOKEN_ID = 23893

device = "cuda" if torch.cuda.is_available() else "cpu"

# Text tokenizer (padding reuses the end-of-sequence token) and the image
# pre-processor that prepares pixels for CLIP.
tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
processor = AutoProcessor.from_pretrained(clip_model_name)

# models
clip_model = CLIPVisionModel.from_pretrained(clip_model_name)
clip_model = clip_model.to(device)

projection = torch.nn.Linear(clip_embed, phi_embed)
projection = projection.to(device)

phi_model = AutoModelForCausalLM.from_pretrained(
    phi_model_name,
    trust_remote_code=True,
)
phi_model = phi_model.to(device)

# load weights
# The LoRA adapter is attached to the base language model and then folded
# into it, leaving a single plain model to run at inference time.
model_to_merge = PeftModel.from_pretrained(phi_model, './model_chkpt/lora_adaptor')
merged_model = model_to_merge.merge_and_unload()

# Restore the trained projection weights, mapping them onto the active device.
_projection_state = torch.load(
    './model_chkpt/step2_projection.pth',
    map_location=torch.device(device),
)
projection.load_state_dict(_projection_state)
def model_generate_ans(img, val_q):
    """Answer a question about an image with greedy decoding.

    The image is encoded by ``clip_model``, mapped into the language model's
    embedding space by ``projection``, followed by the embedding of
    ``IMAGE_TOKEN_ID`` and the embeddings of the tokenised question.  The
    language model is then run step by step, each time appending the embedding
    of the token it just predicted, for at most 100 new tokens or until the
    end-of-sequence token is produced.

    Args:
        img: Image as delivered by the Gradio image component.
        val_q: Question text; ``None`` is treated as an empty question.

    Returns:
        A list holding one decoded answer string (the result of
        ``tokenizer.batch_decode`` on a single row), special tokens removed.
    """
    max_generate_length = 100
    eos_token_id = tokenizer.eos_token_id  # the original hard-coded 50256 here
    if val_q is None:
        val_q = ""

    # Resolve the embedding layer through the public accessor instead of a
    # fixed attribute path; the path differs between a PEFT-wrapped model and
    # the plain model returned by merge_and_unload().
    embed_tokens = merged_model.get_input_embeddings()
    embed_dtype = embed_tokens.weight.dtype

    # Inference only: do not record an autograd graph for up to 100 forwards.
    with torch.no_grad():
        # image
        image_processed = processor(images=img, return_tensors="pt").to(device)
        # Position 0 on the sequence axis is dropped (presumably CLIP's class
        # token - confirm against the training code).
        clip_val_outputs = clip_model(**image_processed).last_hidden_state[:, 1:, :]
        # Match the language model's own dtype instead of forcing float16,
        # which lost precision whenever the model runs in float32.
        val_image_embeds = projection(clip_val_outputs).to(embed_dtype)

        img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
        img_token_embeds = embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)

        # The token ids must live on the same device as the embedding table.
        val_q_tokenised = tokenizer(
            val_q, return_tensors="pt", return_attention_mask=False
        )['input_ids'].squeeze(0).to(device)
        val_q_embeds = embed_tokens(val_q_tokenised).unsqueeze(0)

        # Sequence layout: [image patches][image token][question tokens]
        val_combined_embeds = torch.cat(
            [val_image_embeds, img_token_embeds, val_q_embeds], dim=1
        )

        # Pre-filled with EOS so unused tail positions vanish when decoding.
        predicted_caption = torch.full(
            (1, max_generate_length), eos_token_id, dtype=torch.long, device=device
        )
        for g in range(max_generate_length):
            phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits']
            # Greedy choice from the logits of the last position -> shape (1, 1).
            predicted_word_token = torch.argmax(
                phi_output_logits[:, -1, :], dim=-1, keepdim=True
            )
            predicted_caption[:, g] = predicted_word_token.view(-1)
            if predicted_word_token.item() == eos_token_id:
                break
            # Feed the prediction back in; without this every step sees the
            # same input and the same token is produced over and over.
            next_token_embeds = embed_tokens(predicted_word_token)
            val_combined_embeds = torch.cat(
                [val_combined_embeds, next_token_embeds], dim=1
            )

    # skip_special_tokens is the supported way to drop the EOS filler.
    predicted_captions_decoded = tokenizer.batch_decode(
        predicted_caption, skip_special_tokens=True
    )
    return predicted_captions_decoded
def _answer_for_textbox(img, question):
    """Adapt ``model_generate_ans`` for display in a text box.

    ``model_generate_ans`` returns a list of decoded strings; handing that
    list directly to a text component shows its Python repr (``['...']``).
    This returns the first string instead, or an empty string for an empty
    list.  Non-list results are passed through unchanged.
    """
    answers = model_generate_ans(img, question)
    if isinstance(answers, (list, tuple)):
        return answers[0] if answers else ""
    return answers


with gr.Blocks() as demo:
    # Page header shown above the controls.
    gr.Markdown(
        """
# Chat with MultiModal GPT !
Build using combining clip model and phi-2 model.
"""
    )

    # app GUI
    with gr.Row():
        # Left column: the image and the question about it.
        with gr.Column():
            img_input = gr.Image(label='Image')
            img_question = gr.Text(label ='Question')
        # Right column: the generated answer.
        with gr.Column():
            img_answer = gr.Text(label ='Answer')

    section_btn = gr.Button("Submit")
    # Run the model on click and show the unwrapped answer string.
    section_btn.click(
        _answer_for_textbox,
        inputs=[img_input, img_question],
        outputs=[img_answer],
    )

if __name__ == "__main__":
    demo.launch()