# MultiModal-Phi2 / app.py
# Author: VarunSivamani
# Commit a3c3623 ("application file", verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
import whisperx
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from transformers import CLIPVisionModel, CLIPImageProcessor
import peft
import gradio as gr
# Every model in this app is placed on this device.
device = 'cpu'

# Hugging Face Hub repo holding the fine-tuned QLoRA adapter for Phi-2.
user = "VarunSivamani"
adapter_name = "QLoRA-phi2"
model_id = f"{user}/{adapter_name}"

# Base causal language model the adapter in `model_id` is applied to.
model_name = "microsoft/phi-2"
phi2_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map=device,
)
# NOTE(review): disabling the KV cache is usually a training setting and slows
# autoregressive generation; confirm it is still wanted for inference.
phi2_model.config.use_cache = False

# Speech-to-text model used to transcribe the audio input.
whisper_model = whisperx.load_model('small', device=device, compute_type='float32')

# CLIP vision tower and its matching preprocessor (same checkpoint for both).
clip_model_name = 'openai/clip-vit-base-patch32'
image_processor = CLIPImageProcessor.from_pretrained(clip_model_name)
clip_model = CLIPVisionModel.from_pretrained(clip_model_name)

# Tokenizer of the base model; the EOS token is reused as both pad and BOS token.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.bos_token = tokenizer.eos_token
def text_to_embeddings(text):
    """Tokenize ``text`` and return its Phi-2 input embeddings.

    The token ids come from ``tokenizer`` (no attention mask is requested) and are
    passed through the input-embedding layer of ``phi2_model``.
    """
    token_ids = tokenizer(text, return_tensors="pt", return_attention_mask=False).input_ids
    embedding_layer = phi2_model.get_input_embeddings()
    return embedding_layer(token_ids)
def audio_to_text_embeds(file_name):
    """Transcribe an audio file with ``whisper_model`` and return the transcript.

    Despite the name, this returns plain text, not embeddings. The text of every
    transcription segment is concatenated in order, and surrounding whitespace
    is stripped.

    Args:
        file_name: Audio file path handed to ``whisper_model.transcribe``.

    Returns:
        The stripped transcript string ("" if there are no segments).
    """
    result = whisper_model.transcribe(file_name)
    # A single join avoids the repeated-concatenation pattern of `s = s + part`.
    return "".join(segment['text'] for segment in result['segments']).strip()
def select_features(image_out):
    """Take the last hidden layer from ``image_out`` and drop its first token.

    Position 0 along axis 1 is removed. This is presumably CLIP's class token,
    which would leave only the patch tokens (assumes hidden states are shaped
    (batch, tokens, dim) — confirm).
    """
    last_layer = image_out.hidden_states[-1]
    patch_tokens = last_layer[:, 1:, :]
    return patch_tokens
def CLIP_embeddings(image):
    """Encode ``image`` with the CLIP vision model and return its selected features.

    The image is preprocessed by ``image_processor``, moved to ``clip_model``'s
    device, run with all hidden states requested, and reduced by ``select_features``.
    """
    # Freezes CLIP's parameters; this is repeated on every call and the return value is unused.
    clip_model.requires_grad_(False)
    processed = image_processor(images=image, return_tensors="pt")
    pixel_values = processed['pixel_values'].to(device=clip_model.device)
    vision_output = clip_model(pixel_values, output_hidden_states=True)
    return select_features(vision_output)
class ResBlock(nn.Module):
    """Pre-norm residual MLP block.

    The input is layer-normalized first. The result is then added to a two-layer
    GELU MLP applied to that same normalized tensor, so the skip path carries the
    *normalized* input, not the raw ``x``.

    Args:
        input_size: Size of the last dimension; the output has the same size.
    """

    def __init__(self, input_size):
        super().__init__()
        # These attribute names and the Sequential layout fix the state_dict keys
        # (pre_norm.*, proj.0.*, proj.2.*) that saved weights are loaded into.
        self.pre_norm = nn.LayerNorm(input_size)
        self.proj = nn.Sequential(
            nn.Linear(input_size, input_size),
            nn.GELU(),
            nn.Linear(input_size, input_size),
        )

    def forward(self, x):
        normed = self.pre_norm(x)
        mlp_out = self.proj(normed)
        return normed + mlp_out
class CLIP_projection(nn.Module):
    """Map CLIP vision features into Phi-2's embedding width.

    A bias-free linear layer (``dim_input_CLIP`` -> ``dim_input_Phi2``) is
    followed by a ``ResBlock`` of width ``dim_input_Phi2``.

    Args:
        dim_input_CLIP: Width of the incoming features. The default of 768 is
            presumably the CLIP ViT-B/32 hidden size (confirm).
        dim_input_Phi2: Output width. The default of 2560 is presumably Phi-2's
            hidden size (confirm).
    """

    def __init__(self, dim_input_CLIP=768, dim_input_Phi2=2560):
        super().__init__()
        # Submodule names must match the separately saved weight files.
        self.projection_img = nn.Linear(dim_input_CLIP, dim_input_Phi2, bias=False)
        self.resblock = ResBlock(dim_input_Phi2)

    def forward(self, x):
        projected = self.projection_img(x)
        return self.resblock(projected)
# Projection head from CLIP patch features into Phi-2's embedding space.
proj_layer = CLIP_projection()
# Each submodule's weights are stored as a plain state_dict in its own file.
# weights_only=True restricts unpickling to tensors and simple containers
# instead of arbitrary Python objects.
proj_layer.projection_img.load_state_dict(
    torch.load("proj.pth", map_location=device, weights_only=True)
)
proj_layer.resblock.load_state_dict(
    torch.load("block.pth", map_location=device, weights_only=True)
)
def img_embeddings(image):
    """Embed ``image`` for Phi-2: CLIP features passed through ``proj_layer``."""
    return proj_layer(CLIP_embeddings(image))
# Base Phi-2 model wrapped with the adapter weights loaded from `model_id`.
phi2_model_peft = peft.PeftModel.from_pretrained(model=phi2_model, model_id=model_id)
def multimodal_phi2(image=None, audio=None, text=None):
    """Answer a question about an image and/or audio clip with the Phi-2 adapter.

    The prompt is assembled directly in embedding space as::

        "Context: " <image embeds> <audio transcript embeds> <text embeds>
        " Question: " <query> " Answer: "

    ``phi2_model_peft`` then generates the continuation.

    Args:
        image: Image accepted by ``img_embeddings`` (the ``gr.Image`` value), or None.
        audio: Path to an audio file accepted by ``audio_to_text_embeds``, or None.
        text: Question text. ``None`` and ``""`` (an empty Gradio textbox) both
            mean "no text".

    Returns:
        The decoded answer, or ``None`` when no input was given. The answer is
        the text before the first EOS token; if the output starts with EOS, it is
        the text between the first and second EOS.
    """
    # `text` defaults to None, so a length check would raise; treat None and "" alike.
    if not text:
        text = None
    if image is None and audio is None and text is None:
        return None

    embed_tokens = phi2_model_peft.get_input_embeddings()

    def embed_literal(fragment):
        # Embed a fixed prompt fragment with the adapter model's embedding layer.
        ids = tokenizer(fragment, return_tensors="pt", return_attention_mask=False).input_ids
        return embed_tokens(ids)

    # Inference only: don't build an autograd graph (the projection head's
    # parameters still require grad and would otherwise be tracked per request).
    with torch.no_grad():
        pieces = [embed_literal("Context: ")]
        transcript = None
        if image is not None:
            pieces.append(img_embeddings(image))
        if audio is not None:
            transcript = audio_to_text_embeds(audio)
            pieces.append(text_to_embeddings(transcript))
        if text is not None:
            pieces.append(text_to_embeddings(text))

        # Without typed text, use the spoken transcript as the question, and
        # fall back to an empty question so the prompt is always well-formed.
        if text is not None:
            query = text
        elif transcript is not None:
            query = transcript
        else:
            query = ""
        pieces.append(embed_literal(" Question: " + query))
        pieces.append(embed_literal(" Answer: "))
        input_embeds = torch.cat(pieces, dim=1)

        result = phi2_model_peft.generate(
            inputs_embeds=input_embeds, bos_token_id=tokenizer.bos_token_id
        )

    decoded = tokenizer.batch_decode(result)[0]
    eos = tokenizer.eos_token
    # Keep the text before the first EOS; if the output begins with EOS, keep the
    # segment after it. partition() also copes with an empty decoded string.
    head, _, tail = decoded.partition(eos)
    if head:
        return head
    return tail.split(eos)[0]
# Gradio UI: one input per modality (each may be left empty) and a text answer.
demo = gr.Interface(
    fn=multimodal_phi2,
    inputs=[
        gr.Image(label="Image"),
        gr.Audio(label="Audio", sources=["microphone", "upload"], type="filepath"),
        gr.Textbox(label="Text"),
    ],
    outputs=[
        gr.Textbox(label='Answer'),
    ],
)

# Launch only when run as a script, so importing this module (e.g. by Gradio's
# reload mode or a test) does not start a server.
if __name__ == "__main__":
    demo.launch()