import torch
import torch.nn as nn
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPModel, CLIPProcessor, pipeline
from torchvision import transforms
from PIL import Image
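# Projection head that maps a CLIP image embedding into the language model's embedding
# space. Each of the `width` parallel MLPs produces one lm_hidden_size vector; the
# outputs are concatenated along the sequence dimension, so a single image embedding
# becomes `width` pseudo-token embeddings.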
class _MLPVectorProjector(nn.Module):
    def __init__(
        self, input_hidden_size: int, lm_hidden_size: int, num_layers: int, width: int
    ):
        super(_MLPVectorProjector, self).__init__()
        self.mlps = nn.ModuleList()
        for _ in range(width):
            mlp = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
            for _ in range(1, num_layers):
                mlp.append(nn.GELU())
                mlp.append(nn.Linear(lm_hidden_size, lm_hidden_size, bias=False))
            self.mlps.append(nn.Sequential(*mlp))

    def forward(self, x):
        return torch.cat([mlp(x) for mlp in self.mlps], dim=-2)
## Text model
model_name = "microsoft/phi-2"
with torch.no_grad():
    phi2_text = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True, device_map="auto", torch_dtype=torch.float16
    )
    tokenizer_text = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
## Audio model
model_name_audio = "openai/whisper-small"
pipe = pipeline(
    task="automatic-speech-recognition",
    model=model_name_audio,
    chunk_length_s=30,
    device="cpu",
)
## Image model
# CLIP model
model_id_clip = "openai/clip-vit-base-patch16"
model_clip = CLIPModel.from_pretrained(model_id_clip).to("cpu")
processor_clip = CLIPProcessor.from_pretrained(model_id_clip)
print('--------------Loaded CLIP----------------------')
# Preprocess the image for CLIP
def preprocess_image(image_path):
    image = Image.open(image_path).convert("RGB")
    image = transforms.Resize((224, 224))(image)
    image = transforms.ToTensor()(image)
    return image.unsqueeze(0)
# Get CLIP encoding
def encode_image(image_path):
    image = preprocess_image(image_path).to("cpu")
    # Dummy input_ids for text (only the image embedding is used)
    dummy_text = ""
    inputs = processor_clip(text=dummy_text, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model_clip(**inputs)
    img_embedding = outputs.image_embeds
    return img_embedding
# Get the projection model
img_proj_head = _MLPVectorProjector(512, 2560, 1, 4).to("cpu")
img_proj_head.load_state_dict(torch.load('projection_finetuned.pth', map_location=torch.device('cpu')))
print('--------------Loaded proj head----------------------')
# Get the fine-tuned phi-2 model
with torch.no_grad():
    phi2_finetuned = AutoModelForCausalLM.from_pretrained(
        "phi2_adaptor_fine_tuned", trust_remote_code=True).to("cpu")
print('--------------Loaded fine tuned phi2 model----------------------')
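# Run all three modes together; used as the callback for the gr.Examples block at the bottom of the UI.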
def example_inference(input_text, count, image, img_qn, audio):
    pred_text = textMode(input_text, count)
    pred_text_image = imageMode(image, img_qn)
    pred_text_audio = audioMode(audio)
    return pred_text, pred_text_image, pred_text_audio
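# Text mode: wrap the prompt in a "Question: ... Answer:" template and let the
# fine-tuned phi-2 generate up to `count` new tokens.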
def textMode(text, count):
    count = int(count)
    text = "Question: " + text + " Answer: "
    inputs = tokenizer_text(text, return_tensors="pt", return_attention_mask=False)
    prediction = tokenizer_text.batch_decode(
        phi2_finetuned.generate(
            **inputs,
            max_new_tokens=count,
            bos_token_id=tokenizer_text.bos_token_id,
            eos_token_id=tokenizer_text.eos_token_id,
            pad_token_id=tokenizer_text.pad_token_id
        )
    )
    # str.rstrip() strips a character set, not a suffix, so drop the EOS marker explicitly
    return prediction[0].replace('<|endoftext|>', '').rstrip("\n")
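# Image mode: encode the image with CLIP, project the embedding into phi-2's input
# space with the trained projection head, prepend it to the embedded question tokens,
# and generate the answer from the combined embedding sequence.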
def imageMode(image, question):
    image_embedding = encode_image(image)
    print('-------Image embedding from clip obtained-----------')
    imgToTextEmb = img_proj_head(image_embedding).unsqueeze(0)
    print('-------text embedding from projection obtained-----------')
    question = "Question: " + question + " Answer: "
    Qtokens = torch.tensor(tokenizer_text.encode(question, add_special_tokens=True)).unsqueeze(0)
    Qtoken_embeddings = phi2_finetuned.get_submodule('model.embed_tokens')(Qtokens)
    print('-------question embedding from phi2 obtained-----------')
    inputs = torch.cat((imgToTextEmb, Qtoken_embeddings), dim=-2)
    prediction = tokenizer_text.batch_decode(
        phi2_finetuned.generate(
            inputs_embeds=inputs,
            max_new_tokens=50,
            bos_token_id=tokenizer_text.bos_token_id,
            eos_token_id=tokenizer_text.eos_token_id,
            pad_token_id=tokenizer_text.pad_token_id
        )
    )
    text_pred = prediction[0].replace('<|endoftext|>', '').rstrip("\n")
    return text_pred
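# Audio mode: transcribe the uploaded or recorded clip with Whisper, then run the
# transcript through textMode (50 new tokens).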
def audioMode(audio):
    if audio is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
    print('---------type of audio--------------')
    print(type(audio))
    print(audio)
    text = pipe(audio, batch_size=8, generate_kwargs={"task": "transcribe"}, return_timestamps=True)["text"]
    pred_text = textMode(text, 50)
    return pred_text
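# Gradio UI: one tab per input mode (text / image / audio), plus a combined example row at the bottom.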
interface_title = "Multimodal GPT Application"
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(f"## **{interface_title}**")
    gr.Markdown("Choose an input mode (text / image / audio) to chat with the model")
    with gr.Tab("Text mode"):
        text_input = gr.Textbox(placeholder="Enter a prompt", label="Input")
        text_input_count = gr.Textbox(placeholder="Enter the number of tokens you want to generate", label="Count")
        text_button = gr.Button("Submit")
        text_output = gr.Textbox(label="ChatGPT-like text")
with gr.Tab("Image mode"): | |
with gr.Row(): | |
image_input = gr.Image(type="filepath") | |
image_text_input = gr.Textbox(placeholder="Enter a question/prompt around the image", label="Question/Prompt") | |
image_button = gr.Button("Submit") | |
image_text_output = gr.Textbox(label="Answer") | |
with gr.Tab("Audio mode"): | |
audio_input = gr.Audio(type="filepath") | |
audio_button = gr.Button("Submit") | |
audio_text_output = gr.Textbox(label="Chat GPT like text") | |
text_button.click(textMode, inputs=[text_input, text_input_count], outputs=text_output) | |
image_button.click(imageMode, inputs=[image_input,image_text_input], outputs=image_text_output) | |
audio_button.click(audioMode, inputs=audio_input, outputs=audio_text_output) | |
    gr.Examples(
        examples=[
            ["Briefly explain the geographical features of India?", "50", "img69.jpg", "What is the man behind the counter doing?", "audio_ex3.mp3"]
        ],
        inputs=[text_input, text_input_count, image_input, image_text_input, audio_input],
        outputs=[text_output, image_text_output, audio_text_output],
        fn=example_inference,
    )
demo.launch()