# Capstone / app.py
# TharunSivamani's picture
# application file
# 5f3e0e0 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
import whisperx
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from transformers import CLIPVisionModel, CLIPImageProcessor
import peft
import gradio as gr
# ---------------------------------------------------------------------------
# Global configuration and pretrained components.
# WhisperX and Phi-2 are explicitly placed on the CPU below.
# ---------------------------------------------------------------------------
device = 'cpu'
model_name = "microsoft/phi-2"

# Speech-to-text model (its `transcribe` method is used by `embed_audio`).
whisper_model = whisperx.load_model('small', device='cpu', compute_type='float32')

# CLIP preprocessor and vision encoder, both from the same checkpoint.
image_processor = CLIPImageProcessor.from_pretrained('openai/clip-vit-base-patch32')
clip_model = CLIPVisionModel.from_pretrained('openai/clip-vit-base-patch32')

# Slow (non-fast) Phi-2 tokenizer; the EOS token is reused as both the
# padding token and the BOS token.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=False,
    trust_remote_code=True,
)
tokenizer.pad_token = tokenizer.bos_token = tokenizer.eos_token

# Phi-2 causal language model, with `use_cache` disabled in its config.
phi2_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map='cpu',
    trust_remote_code=True,
)
phi2_model.config.use_cache = False
def CLIP_embeddings(image):
    """Run an image through CLIP and return the features picked by ``features``.

    The image is preprocessed with ``image_processor``, moved to the device
    of ``clip_model``, and encoded with hidden states enabled so that
    ``features`` can select from them.
    """
    # Disable gradients on CLIP's parameters (re-applied on every call).
    clip_model.requires_grad_(False)

    pixel_values = image_processor(images=image, return_tensors="pt")['pixel_values']
    clip_output = clip_model(
        pixel_values.to(device=clip_model.device),
        output_hidden_states=True,
    )
    return features(clip_output)
def embed_audio(file_name):
    """Transcribe an audio file with ``whisper_model`` and return the text.

    The ``'text'`` field of every entry in the result's ``'segments'`` list
    is concatenated in order, without a separator, and the combined string
    is stripped of leading and trailing whitespace.
    """
    transcription = whisper_model.transcribe(file_name)
    return ''.join(chunk['text'] for chunk in transcription['segments']).strip()
def features(image_out):
    """Return the final hidden state without its first sequence position.

    ``image_out.hidden_states[-1]`` is indexed as a 3-D tensor and everything
    from index 1 onwards along the second axis is kept.
    """
    final_hidden_state = image_out.hidden_states[-1]
    without_first_position = final_hidden_state[:, 1:, :]
    return without_first_position
def embed_text(text):
    """Tokenize ``text`` and return its Phi-2 input embeddings.

    The token ids come from ``tokenizer`` (as PyTorch tensors, without an
    attention mask) and are looked up in the input embedding layer of
    ``phi2_model``.
    """
    embedding_layer = phi2_model.get_input_embeddings()
    token_ids = tokenizer(text, return_tensors="pt", return_attention_mask=False).input_ids
    return embedding_layer(token_ids)
class ResBlock(nn.Module):
    """Layer-normalized two-layer MLP with a skip connection.

    The skip connection starts *after* the normalization: with
    ``n = LayerNorm(x)`` the block returns ``n + proj(n)``, where ``proj`` is
    ``Linear -> GELU -> Linear`` and every layer keeps the width
    ``input_size``.
    """

    def __init__(self, input_size):
        super().__init__()
        self.pre_norm = nn.LayerNorm(input_size)
        # Attribute names and layer order are kept as-is so previously saved
        # state dicts (keys ``proj.0.*`` / ``proj.2.*``) still load.
        mlp_layers = [
            nn.Linear(input_size, input_size),
            nn.GELU(),
            nn.Linear(input_size, input_size),
        ]
        self.proj = nn.Sequential(*mlp_layers)

    def forward(self, x):
        normed = self.pre_norm(x)
        return normed + self.proj(normed)
class Projection_Model(nn.Module):
    """Project CLIP features to the Phi-2 embedding width.

    A bias-free linear layer maps the last dimension from ``dim_input_CLIP``
    to ``dim_input_Phi2``; the result is then passed through a ``ResBlock``
    of width ``dim_input_Phi2``.
    """

    def __init__(self, dim_input_CLIP=768, dim_input_Phi2=2560):
        super().__init__()
        self.projection_img = nn.Linear(dim_input_CLIP, dim_input_Phi2, bias=False)
        self.resblock = ResBlock(dim_input_Phi2)

    def forward(self, x):
        projected = self.projection_img(x)
        return self.resblock(projected)
# Build the projection model and restore the saved weights of its two
# sub-modules from local checkpoint files, mapped onto the CPU.
model = Projection_Model()

projection_weights = torch.load("projection.pth", map_location='cpu')
model.projection_img.load_state_dict(projection_weights)

resblock_weights = torch.load("block.pth", map_location='cpu')
model.resblock.load_state_dict(resblock_weights)
def embeddings_image(image):
    """Return the image's CLIP features after passing them through ``model``."""
    return model(CLIP_embeddings(image))
# Repository id ("<user>/<name>") of the PEFT adapter that is applied on top
# of the base Phi-2 model. Note that `model_name` is rebound here.
user = "TharunSivamani"
model_name = "qlora-phi2"
model_id = "/".join((user, model_name))

phi2_model_peft = peft.PeftModel.from_pretrained(phi2_model, model_id)
def inference(image=None, audio=None, text=None):
    """Generate an answer from any combination of image, audio and text.

    The prompt is assembled in embedding space, in this order::

        "Context: " [image] [audio transcript] [text] " Question: " <query> " Answer: "

    and handed to ``phi2_model_peft.generate``.

    Args:
        image: Image accepted by ``embeddings_image``, or ``None``.
        audio: Audio file path accepted by ``embed_audio``, or ``None``.
        text: Question text. ``None`` and ``""`` both mean "no text".

    Returns:
        The decoded answer string, or ``None`` when no input was supplied.
    """
    # A missing text value and an empty one are equivalent; the previous
    # `len(text)` check raised TypeError when `text` was None.
    if not text:
        text = None
    if image is None and audio is None and text is None:
        return None

    embedding_layer = phi2_model_peft.get_input_embeddings()

    def embed_prompt(fragment):
        # Tokenize a prompt fragment and embed it with the adapter model.
        token_ids = tokenizer(
            fragment, return_tensors="pt", return_attention_mask=False
        ).input_ids
        return embedding_layer(token_ids)

    segments = [embed_prompt("Context: ")]
    audio_transcribed = None

    if image is not None:
        segments.append(embeddings_image(image))

    if audio is not None:
        audio_transcribed = embed_audio(audio)
        # An empty transcript contributes no tokens, so there is nothing to
        # embed; skipping it is equivalent to concatenating zero tokens.
        if audio_transcribed:
            segments.append(embed_text(audio_transcribed))

    if text is not None:
        segments.append(embed_text(text))

    # The question is the typed text when present, otherwise the audio
    # transcript, otherwise empty. Previously `query` was unbound for
    # audio-only input and None for image-only input, both of which crashed
    # on the string concatenation below.
    query = text if text is not None else (audio_transcribed or "")

    segments.append(embed_prompt(" Question: " + query))
    segments.append(embed_prompt(" Answer: "))
    input_embeds = torch.cat(segments, dim=1)

    result = phi2_model_peft.generate(
        inputs_embeds=input_embeds,
        bos_token_id=tokenizer.bos_token_id,
    )
    pieces = tokenizer.batch_decode(result)[0].split(tokenizer.eos_token)

    # Skip a leading empty piece (output that starts with the EOS token), but
    # only when another piece exists, so empty output cannot raise IndexError.
    if pieces[0] == '' and len(pieces) > 1:
        return pieces[1]
    return pieces[0]
# Gradio UI: image, audio (microphone or upload, passed as a file path) and
# text inputs are forwarded to `inference`; the answer is shown in a textbox.
image_input = gr.Image(label="Image Input")
audio_input = gr.Audio(label="Audio Input", sources=["microphone", "upload"], type="filepath")
text_input = gr.Textbox(label="Text Input")
answer_output = gr.Textbox(label='Answer')

demo = gr.Interface(
    fn=inference,
    inputs=[image_input, audio_input, text_input],
    outputs=[answer_output],
)
demo.launch()