import torch
import torch.nn as nn
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel
from PIL import Image


class _MLPVectorProjector(nn.Module):
    """Project one vision embedding into ``width`` language-model embeddings.

    ``width`` independent MLPs are applied to the same input and their outputs
    are concatenated along dim -2, so each MLP contributes one "virtual token".

    Args:
        input_hidden_size: Size of the last dimension of the input.
        lm_hidden_size: Output size of every Linear layer (LM embedding size).
        num_layers: Number of Linear layers per MLP, with GELU between them.
        width: Number of parallel MLPs, i.e. projected tokens per input.
    """

    def __init__(
        self, input_hidden_size: int, lm_hidden_size: int, num_layers: int, width: int
    ):
        super().__init__()
        # Module names/ordering must stay as-is: they define the state_dict keys
        # that projection_finetuned.pth is loaded into below.
        self.mlps = nn.ModuleList()
        for _ in range(width):
            mlp = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
            for _ in range(1, num_layers):
                mlp.append(nn.GELU())
                mlp.append(nn.Linear(lm_hidden_size, lm_hidden_size, bias=False))
            self.mlps.append(nn.Sequential(*mlp))

    def forward(self, x):
        # Concatenate the per-MLP outputs along dim -2 so they become separate tokens.
        return torch.cat([mlp(x) for mlp in self.mlps], dim=-2)


## Text model
model_name = "microsoft/phi-2"
# NOTE(review): with no GPU available, device_map="auto" leaves this float16
# model on CPU, where some kernels may not support half precision — confirm
# on the target hardware.
phi2_text = AutoModelForCausalLM.from_pretrained(
    model_name, trust_remote_code=True, device_map="auto", torch_dtype=torch.float16
)
tokenizer_text = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

## Audio model
model_name_audio = "openai/whisper-small"
pipe = pipeline(
    task="automatic-speech-recognition",
    model=model_name_audio,
    chunk_length_s=30,
    device="cpu",
)

## Image model
# CLIP model
model_id_clip = "openai/clip-vit-base-patch16"
model_clip = CLIPModel.from_pretrained(model_id_clip).to("cpu")
processor_clip = CLIPProcessor.from_pretrained(model_id_clip)
print('--------------Loaded CLIP----------------------')


def preprocess_image(image_path):
    """Load an image file as an RGB 224x224 tensor with a leading batch dim of 1.

    NOTE(review): ToTensor() already scales pixels to [0, 1], and the result is
    then passed to CLIPProcessor, which by default rescales by 1/255 again.
    The projection head was presumably trained on embeddings from this same
    pipeline, so it is left unchanged — confirm before "fixing" it.
    """
    image = Image.open(image_path).convert("RGB")
    image = transforms.Resize((224, 224))(image)
    image = transforms.ToTensor()(image)
    return image.unsqueeze(0)


def encode_image(image_path):
    """Return CLIP's ``image_embeds`` for the image stored at ``image_path``."""
    image = preprocess_image(image_path).to("cpu")
    # CLIPModel.forward needs text inputs too; an empty string serves as a dummy.
    dummy_text = ""
    inputs = processor_clip(text=dummy_text, images=image, return_tensors="pt", padding=True)
    outputs = model_clip(**inputs)
    img_embedding = outputs.image_embeds
    return img_embedding


# Projection model: CLIP embedding (512) -> 4 phi-2 embeddings (2560 each).
img_proj_head = _MLPVectorProjector(512, 2560, 1, 4).to("cpu")
img_proj_head.load_state_dict(
    torch.load('projection_finetuned.pth', map_location=torch.device('cpu'))
)
print('--------------Loaded proj head----------------------')

# Fine-tuned phi-2 model used for image question answering.
phi2_finetuned = AutoModelForCausalLM.from_pretrained(
    "phi2_adaptor_fineTuned", trust_remote_code=True
).to("cpu")
print('--------------Loaded fine tuned phi2 model----------------------')

_END_OF_TEXT = "<|endoftext|>"


def _strip_end_of_text(text):
    """Remove end-of-text markers and trailing newlines from decoded output.

    ``str.strip``/``rstrip`` cannot be used with the marker: they treat their
    argument as a *set of characters* and would also eat legitimate text
    (e.g. the "end" in "the end").
    """
    return text.replace(_END_OF_TEXT, "").rstrip("\n")


def example_inference(input_text, count):
    """Run text mode for the Gradio examples table."""
    pred_text = textMode(input_text, count)
    return pred_text


def textMode(text, count):
    """Continue ``text`` with the base phi-2 model.

    Args:
        text: Prompt string.
        count: Maximum number of new *tokens* to generate. It comes from a
            Gradio Textbox, so it may be a string; it is converted with ``int``.

    Returns:
        The decoded ``generate`` output, with end-of-text markers and trailing
        newlines removed.

    Raises:
        gr.Error: If ``count`` is not an integer >= 1.
    """
    try:
        max_new_tokens = int(count)
    except (TypeError, ValueError):
        raise gr.Error("Count must be a whole number of tokens to generate.") from None
    if max_new_tokens < 1:
        raise gr.Error("Count must be at least 1.")
    # The tokenizer returns CPU tensors, while device_map="auto" may have put
    # phi2_text on a GPU; move the inputs to the model's device.
    inputs = tokenizer_text(
        text, return_tensors="pt", return_attention_mask=False
    ).to(phi2_text.device)
    output_ids = phi2_text.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        bos_token_id=tokenizer_text.bos_token_id,
        eos_token_id=tokenizer_text.eos_token_id,
        pad_token_id=tokenizer_text.pad_token_id,
    )
    prediction = tokenizer_text.batch_decode(output_ids)
    return _strip_end_of_text(prediction[0])


def imageMode(image, question):
    """Answer ``question`` about the image at filepath ``image``.

    The image's CLIP embedding is projected into 4 phi-2 embeddings, which are
    prepended to the embedded question and fed to the fine-tuned phi-2 model.

    Raises:
        gr.Error: If no image was submitted.
    """
    if image is None:
        raise gr.Error("No image submitted! Please upload an image before submitting your request.")
    # NOTE(review): there is no separator between the question and "Answer: ".
    # It is kept as-is in case the fine-tuned model was trained on exactly this format.
    prompt = "Question: " + (question or "") + "Answer: "
    # Inference only: skip building autograd graphs for CLIP and the projector.
    with torch.no_grad():
        image_embedding = encode_image(image)
        print('-------Image embedding from clip obtained-----------')
        imgToTextEmb = img_proj_head(image_embedding).unsqueeze(0)
        print('-------text embedding from projection obtained-----------')
        Qtokens = torch.tensor(
            tokenizer_text.encode(prompt, add_special_tokens=True)
        ).unsqueeze(0)
        Qtoken_embeddings = phi2_finetuned.get_submodule('model.embed_tokens')(Qtokens)
        print('-------question embedding from phi2 obtained-----------')
        # Image "tokens" first, then the question tokens, along the sequence dim.
        inputs = torch.cat((imgToTextEmb, Qtoken_embeddings), dim=-2)
        output_ids = phi2_finetuned.generate(
            inputs_embeds=inputs,
            max_new_tokens=50,
            bos_token_id=tokenizer_text.bos_token_id,
            eos_token_id=tokenizer_text.eos_token_id,
            pad_token_id=tokenizer_text.pad_token_id,
        )
    prediction = tokenizer_text.batch_decode(output_ids)
    return _strip_end_of_text(prediction[0])


def audioMode(audio):
    """Transcribe the audio file with Whisper, then continue the transcript with phi-2.

    Raises:
        gr.Error: If no audio was submitted.
    """
    if audio is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
    print('---------type of audio--------------')
    print(type(audio))
    print(audio)
    text = pipe(audio, batch_size=8, generate_kwargs={"task": "transcribe"}, return_timestamps=True)["text"]
    pred_text = textMode(text, 50)
    return pred_text


interface_title = "TSAI-ERA-V1 - Capstone - Multimodal GPT Demo"

with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(f"## **{interface_title}**")
    gr.Markdown("Choose text mode/image mode/audio mode for generation")

    with gr.Tab("Text mode"):
        text_input = gr.Textbox(placeholder="Enter a prompt", label="Input")
        text_input_count = gr.Textbox(placeholder="Enter number of tokens you want to generate", label="Count")
        text_button = gr.Button("Submit")
        text_output = gr.Textbox(label="Chat GPT like text")

    with gr.Tab("Image mode"):
        with gr.Row():
            image_input = gr.Image(type="filepath")
            image_text_input = gr.Textbox(placeholder="Enter a question/prompt around the image", label="Question/Prompt")
        image_button = gr.Button("Submit")
        image_text_output = gr.Textbox(label="Answer")

    with gr.Tab("Audio mode"):
        audio_input = gr.Audio(type="filepath")
        audio_button = gr.Button("Submit")
        audio_text_output = gr.Textbox(label="Chat GPT like text")

    text_button.click(textMode, inputs=[text_input, text_input_count], outputs=text_output)
    image_button.click(imageMode, inputs=[image_input, image_text_input], outputs=image_text_output)
    audio_button.click(audioMode, inputs=audio_input, outputs=audio_text_output)

    gr.Examples(
        examples=[
            ["What is a large language model?", "50"],
        ],
        inputs=[text_input, text_input_count],
        outputs=[text_output],
        fn=example_inference,
    )

demo.launch()