# Multimodal Phi-2 demo: CLIP image features and WhisperX transcriptions are
# projected into phi-2's embedding space and answered by a QLoRA-tuned adapter,
# served through a Gradio interface.
import torch
import torch.nn as nn
import torch.nn.functional as F
import whisperx
from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPVisionModel, CLIPImageProcessor
import peft
import gradio as gr

device = 'cpu'

# QLoRA adapter repository on the Hugging Face Hub
user = "VarunSivamani"
adapter_name = "QLoRA-phi2"
model_id = f"{user}/{adapter_name}"

# Base model the adapter was trained on
model_name = "microsoft/phi-2"

phi2_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map='cpu'
)
phi2_model.config.use_cache = False

# Speech-to-text model (WhisperX "small") running on CPU
whisper_model = whisperx.load_model('small', device='cpu', compute_type='float32')

# CLIP vision tower used as a frozen image feature extractor
image_processor = CLIPImageProcessor.from_pretrained('openai/clip-vit-base-patch32')
clip_model = CLIPVisionModel.from_pretrained('openai/clip-vit-base-patch32')

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.bos_token = tokenizer.eos_token

def text_to_embeddings(text):
    # Tokenize the text and map the token ids to phi-2 input embeddings
    input_tokens = tokenizer(text, return_tensors="pt", return_attention_mask=False)
    return phi2_model.get_input_embeddings()(input_tokens.input_ids)
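
# Quick sanity check of the helper above (hedged sketch; 2560 is phi-2's
# embedding dimension, and the token count depends on the input text):
#   emb = text_to_embeddings("hello world")   # -> shape (1, num_tokens, 2560)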

def audio_to_text_embeds(file_name):
    # Transcribe the audio file with WhisperX and join the segment texts
    result = whisper_model.transcribe(file_name)
    res_text = ''
    for segment in result['segments']:
        res_text = res_text + segment['text']
    return res_text.strip()
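
# Example (hypothetical file name; despite its name, this function returns the
# transcription as plain text, not embeddings):
#   audio_to_text_embeds("question.wav")   # -> "What is shown in the image?"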

def select_features(image_out):
    # Take the last hidden state and drop the leading CLS token,
    # keeping only the patch tokens
    image_features = image_out.hidden_states[-1]
    return image_features[:, 1:, :]


def CLIP_embeddings(image):
    # CLIP stays frozen; it is only used as a feature extractor
    _ = clip_model.requires_grad_(False)
    image = image_processor(images=image, return_tensors="pt")
    image_out = clip_model(image['pixel_values'].to(device=clip_model.device), output_hidden_states=True)
    return select_features(image_out)
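
# Expected shape (hedged sketch, assuming the default 224x224 preprocessing of
# openai/clip-vit-base-patch32: 7x7 = 49 patch tokens with 768-dim hidden states):
#   feats = CLIP_embeddings(pil_image)   # -> shape (1, 49, 768)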

class ResBlock(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.pre_norm = nn.LayerNorm(input_size)
        self.proj = nn.Sequential(
            nn.Linear(input_size, input_size),
            nn.GELU(),
            nn.Linear(input_size, input_size)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)

class CLIP_projection(nn.Module):
    # Projects 768-dim CLIP patch features into phi-2's 2560-dim embedding space,
    # followed by a residual MLP block
    def __init__(
        self,
        dim_input_CLIP=768,
        dim_input_Phi2=2560
    ):
        super().__init__()
        self.projection_img = nn.Linear(
            dim_input_CLIP, dim_input_Phi2, bias=False
        )
        self.resblock = ResBlock(dim_input_Phi2)

    def forward(self, x):
        x = self.projection_img(x)
        return self.resblock(x)

# Load the pretrained projection weights (expected alongside app.py in the Space)
proj_layer = CLIP_projection()
proj_layer.projection_img.load_state_dict(torch.load("proj.pth", map_location='cpu'))
proj_layer.resblock.load_state_dict(torch.load("block.pth", map_location='cpu'))


def img_embeddings(image):
    clip_embeddings = CLIP_embeddings(image)
    return proj_layer(clip_embeddings)
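
# End-to-end image path (hedged sketch; shapes follow from the notes above):
#   img_embeddings(pil_image)   # -> shape (1, 49, 2560), ready to be concatenated
#                               #    with phi-2 text embeddings along dim=1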

# Attach the QLoRA adapter to the base phi-2 model
phi2_model_peft = peft.PeftModel.from_pretrained(phi2_model, model_id)

def multimodal_phi2(image=None, audio=None, text=None):
    # Treat an empty textbox as "no text"
    if not text:
        text = None
    if image is None and audio is None and text is None:
        return None

    # The question appended below defaults to the text input; if only audio is
    # given, the transcription is used instead so the question is never empty
    query = ""

    # Prompt layout: "Context: <image/audio/text embeddings> Question: ... Answer: "
    context = tokenizer("Context: ", return_tensors="pt", return_attention_mask=False)
    input_embeds = phi2_model_peft.get_input_embeddings()(context.input_ids)

    if image is not None:
        image_embeds = img_embeddings(image)
        input_embeds = torch.cat((input_embeds, image_embeds), dim=1)

    if audio is not None:
        audio_transcribed = audio_to_text_embeds(audio)
        if text is None:
            query = audio_transcribed
        audio_embeds = text_to_embeddings(audio_transcribed)
        input_embeds = torch.cat((input_embeds, audio_embeds), dim=1)

    if text is not None:
        query = text
        text_embeds = text_to_embeddings(text)
        input_embeds = torch.cat((input_embeds, text_embeds), dim=1)

    question = tokenizer(" Question: " + query, return_tensors="pt", return_attention_mask=False)
    question_embeds = phi2_model_peft.get_input_embeddings()(question.input_ids)
    input_embeds = torch.cat((input_embeds, question_embeds), dim=1)

    answer = tokenizer(" Answer: ", return_tensors="pt", return_attention_mask=False)
    answer_embeds = phi2_model_peft.get_input_embeddings()(answer.input_ids)
    input_embeds = torch.cat((input_embeds, answer_embeds), dim=1)

    result = phi2_model_peft.generate(inputs_embeds=input_embeds, bos_token_id=tokenizer.bos_token_id)
    process = tokenizer.batch_decode(result)[0]

    # The decoded output may start with an EOS token; return the first non-empty chunk
    process = process.split(tokenizer.eos_token)
    if process[0] == '':
        return process[1]
    else:
        return process[0]
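
# Example calls (hedged sketch; the file name is hypothetical, and generation
# length uses the model's default since max_new_tokens is not set above):
#   multimodal_phi2(image=pil_image, text="What is in this picture?")
#   multimodal_phi2(audio="question.wav")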

demo = gr.Interface(
    fn=multimodal_phi2,
    inputs=[
        gr.Image(label="Image"),
        gr.Audio(label="Audio", sources=["microphone", "upload"], type="filepath"),
        gr.Textbox(label="Text"),
    ],
    outputs=[
        gr.Textbox(label="Answer"),
    ],
)

demo.launch()