import torch
import torch.nn as nn
import whisperx
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from transformers import CLIPVisionModel, CLIPImageProcessor
import peft
import gradio as gr

device = 'cpu'
model_name = "microsoft/phi-2"

# Speech-to-text model (WhisperX) and CLIP vision encoder for the audio and image modalities.
whisper_model = whisperx.load_model('small', device=device, compute_type='float32')
image_processor = CLIPImageProcessor.from_pretrained('openai/clip-vit-base-patch32')
clip_model = CLIPVisionModel.from_pretrained('openai/clip-vit-base-patch32')

# Phi-2 tokenizer; pad/bos tokens are mapped to eos since Phi-2 does not define them.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.bos_token = tokenizer.eos_token

phi2_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map='cpu'
)
phi2_model.config.use_cache = False


def CLIP_embeddings(image):
    """Run the image through CLIP and return its patch embeddings."""
    _ = clip_model.requires_grad_(False)
    image = image_processor(images=image, return_tensors="pt")
    image_out = clip_model(image['pixel_values'].to(device=clip_model.device), output_hidden_states=True)
    return features(image_out)


def embed_audio(file_name):
    """Transcribe an audio file with WhisperX and return the concatenated text."""
    result = whisper_model.transcribe(file_name)
    res_text = ''
    for segment in result['segments']:
        res_text = res_text + segment['text']
    return res_text.strip()


def features(image_out):
    """Take the last hidden state and drop the CLS token, keeping only patch tokens."""
    image_features = image_out.hidden_states[-1]
    return image_features[:, 1:, :]


def embed_text(text):
    """Tokenize text and map it to Phi-2 input embeddings."""
    input_tokens = tokenizer(text, return_tensors="pt", return_attention_mask=False)
    return phi2_model.get_input_embeddings()(input_tokens.input_ids)


class ResBlock(nn.Module):
    """Pre-norm residual MLP block applied on top of the image projection."""

    def __init__(self, input_size):
        super().__init__()
        self.pre_norm = nn.LayerNorm(input_size)
        self.proj = nn.Sequential(
            nn.Linear(input_size, input_size),
            nn.GELU(),
            nn.Linear(input_size, input_size)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)


class Projection_Model(nn.Module):
    """Projects CLIP patch embeddings (768-d) into the Phi-2 embedding space (2560-d)."""

    def __init__(self, dim_input_CLIP=768, dim_input_Phi2=2560):
        super().__init__()
        self.projection_img = nn.Linear(dim_input_CLIP, dim_input_Phi2, bias=False)
        self.resblock = ResBlock(dim_input_Phi2)

    def forward(self, x):
        x = self.projection_img(x)
        return self.resblock(x)


# Load the pretrained projection weights.
model = Projection_Model()
model.projection_img.load_state_dict(torch.load("projection.pth", map_location='cpu'))
model.resblock.load_state_dict(torch.load("block.pth", map_location='cpu'))
model.eval()


def embeddings_image(image):
    """CLIP patch features -> projection into the Phi-2 embedding space."""
    clip_embeddings = CLIP_embeddings(image)
    return model(clip_embeddings)


# QLoRA adapter fine-tuned on top of Phi-2.
user = "TharunSivamani"
adapter_name = "qlora-phi2"
model_id = f"{user}/{adapter_name}"

phi2_model_peft = peft.PeftModel.from_pretrained(phi2_model, model_id)


@torch.no_grad()
def inference(image=None, audio=None, text=None):
    # Treat an empty string from the Gradio textbox as "no text input".
    if not text:
        text = None
    if image is None and audio is None and text is None:
        return None

    # The question defaults to the text prompt (empty if only image/audio is given).
    query = text if text is not None else ""

    # Prompt layout: "Context: <image embeds> <audio transcript embeds> <text embeds> Question: <query> Answer: "
    context = tokenizer("Context: ", return_tensors="pt", return_attention_mask=False)
    input_embeds = phi2_model_peft.get_input_embeddings()(context.input_ids)

    if image is not None:
        image_embeds = embeddings_image(image)
        input_embeds = torch.cat((input_embeds, image_embeds), dim=1)

    if audio is not None:
        audio_transcribed = embed_audio(audio)
        audio_embeds = embed_text(audio_transcribed)
        input_embeds = torch.cat((input_embeds, audio_embeds), dim=1)

    if text is not None:
        text_embeds = embed_text(text)
        input_embeds = torch.cat((input_embeds, text_embeds), dim=1)

    question = tokenizer(" Question: " + query, return_tensors="pt", return_attention_mask=False)
    question_embeds = phi2_model_peft.get_input_embeddings()(question.input_ids)
    input_embeds = torch.cat((input_embeds, question_embeds), dim=1)

    answer = tokenizer(" Answer: ", return_tensors="pt", return_attention_mask=False)
    answer_embeds = phi2_model_peft.get_input_embeddings()(answer.input_ids)
    input_embeds = torch.cat((input_embeds, answer_embeds), dim=1)

    result = phi2_model_peft.generate(inputs_embeds=input_embeds, bos_token_id=tokenizer.bos_token_id)
    final_ans = tokenizer.batch_decode(result)[0]
    final_ans = final_ans.split(tokenizer.eos_token)
    # Generation may start with an eos token; return the first non-empty chunk.
    if final_ans[0] == '':
        return final_ans[1]
    else:
        return final_ans[0]


demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(label="Image Input"),
        gr.Audio(label="Audio Input", sources=["microphone", "upload"], type="filepath"),
        gr.Textbox(label="Text Input"),
    ],
    outputs=[
        gr.Textbox(label="Answer"),
    ],
)

demo.launch()
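# Optional sanity check (not part of the original app): one hypothetical way to call
# inference() directly from Python for a text-only query, bypassing the Gradio UI.
# The example question string below is made up; if used, run it before demo.launch().
#
#     print(inference(image=None, audio=None, text="What is the capital of France?"))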