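"""Gradio demo for a multimodal (text / image / audio) chat app built on microsoft/phi-2.

Text prompts are answered by a fine-tuned phi-2 model; images are encoded with
CLIP and mapped into phi-2's embedding space by a small MLP projection head;
audio is transcribed with Whisper and the transcript is routed through the
text path.
"""
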
import torch
import torch.nn as nn
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPModel, CLIPProcessor, pipeline
from torchvision import transforms
from PIL import Image


class _MLPVectorProjector(nn.Module):
    """Projects one embedding vector into `width` LM-token embeddings, each
    produced by its own small MLP (input_hidden_size -> lm_hidden_size)."""

    def __init__(
        self, input_hidden_size: int, lm_hidden_size: int, num_layers: int, width: int
    ):
        super().__init__()
        self.mlps = nn.ModuleList()
        for _ in range(width):
            mlp = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
            for _ in range(1, num_layers):
                mlp.append(nn.GELU())
                mlp.append(nn.Linear(lm_hidden_size, lm_hidden_size, bias=False))
            self.mlps.append(nn.Sequential(*mlp))

    def forward(self, x):
        # Concatenate the per-MLP outputs along the sequence dimension so the
        # result reads as `width` consecutive token embeddings.
        return torch.cat([mlp(x) for mlp in self.mlps], dim=-2)
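
# Shape sketch for the projector as configured below (512 -> 2560, width=4,
# num_layers=1): a CLIP image embedding of shape (1, 512) goes through each of
# the 4 single-layer MLPs to (1, 2560); concatenating along dim=-2 gives
# (4, 2560), i.e. 4 pseudo-token embeddings in phi-2's hidden size.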

## Text model

model_name = "microsoft/phi-2"

# Base phi-2 model; generation below uses the fine-tuned model loaded further
# down, while this tokenizer is shared by all modes.
with torch.no_grad():
    phi2_text = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True, device_map="auto", torch_dtype=torch.float16
    )

tokenizer_text = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

## Audio model
model_name_audio = "openai/whisper-small"
pipe = pipeline(
    task="automatic-speech-recognition",
    model=model_name_audio,
    chunk_length_s=30,
    device="cpu",
)

## Image model
# CLIP model
model_id_clip = "openai/clip-vit-base-patch16"
model_clip = CLIPModel.from_pretrained(model_id_clip).to("cpu")
processor_clip = CLIPProcessor.from_pretrained(model_id_clip)

print('--------------Loaded CLIP----------------------')

# Preprocess the image for CLIP
def preprocess_image(image_path):
    image = Image.open(image_path).convert("RGB")
    image = transforms.Resize((224, 224))(image)
    image = transforms.ToTensor()(image)
    return image.unsqueeze(0)  # shape: (1, 3, 224, 224)

# Get the CLIP image embedding
def encode_image(image_path):
    image = preprocess_image(image_path).to("cpu")
    # Dummy text input so the processor call succeeds; only image_embeds is used
    dummy_text = ""
    inputs = processor_clip(text=dummy_text, images=image, return_tensors="pt", padding=True)
    outputs = model_clip(**inputs)
    img_embedding = outputs.image_embeds
    return img_embedding  # shape: (1, 512)

# Projection head: maps the 512-dim CLIP image embedding to 4 pseudo-token
# embeddings in phi-2's 2560-dim hidden space (see _MLPVectorProjector above)
img_proj_head = _MLPVectorProjector(512, 2560, 1, 4).to("cpu")
img_proj_head.load_state_dict(torch.load('projection_finetuned.pth', map_location=torch.device('cpu')))

print('--------------Loaded proj head----------------------')

# Get the fine-tuned phi-2 model
with torch.no_grad():
    phi2_finetuned = AutoModelForCausalLM.from_pretrained(
        "phi2_adaptor_fine_tuned", trust_remote_code=True).to("cpu")

print('--------------Loaded fine tuned phi2 model----------------------')


def example_inference(input_text, count, image, img_qn, audio):
    # Used by gr.Examples below: runs all three modes on a single example row
    pred_text = textMode(input_text, count)
    pred_text_image = imageMode(image, img_qn)
    pred_text_audio = audioMode(audio)
    return pred_text, pred_text_image, pred_text_audio


def textMode(text, count):
    count = int(count)
    text = "Question: " + text + " Answer: "
    inputs = tokenizer_text(text, return_tensors="pt", return_attention_mask=False)
    prediction = tokenizer_text.batch_decode(
        phi2_finetuned.generate(
            **inputs,
            max_new_tokens=count,
            bos_token_id=tokenizer_text.bos_token_id,
            eos_token_id=tokenizer_text.eos_token_id,
            pad_token_id=tokenizer_text.pad_token_id
        )
    )
    # Remove the end-of-text marker as a whole token; rstrip('<|endoftext|>')
    # would strip matching characters off the end of the answer instead.
    return prediction[0].replace('<|endoftext|>', '').rstrip("\n")
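# Example (hypothetical prompt): textMode("What is the capital of France?", 30)
# returns the decoded prompt followed by up to 30 newly generated tokens.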


def imageMode(image, question):
    image_embedding = encode_image(image)
    print('-------Image embedding from clip obtained-----------')
    imgToTextEmb = img_proj_head(image_embedding).unsqueeze(0)
    print('-------text embedding from projection obtained-----------')
    question = "Question: " + question + " Answer: "
    Qtokens = torch.tensor(tokenizer_text.encode(question, add_special_tokens=True)).unsqueeze(0)
    Qtoken_embeddings = phi2_finetuned.get_submodule('model.embed_tokens')(Qtokens)
    print('-------question embedding from phi2 obtained-----------')
    # Prepend the projected image pseudo-tokens to the question embeddings
    inputs = torch.cat((imgToTextEmb, Qtoken_embeddings), dim=-2)

    prediction = tokenizer_text.batch_decode(
        phi2_finetuned.generate(
            inputs_embeds=inputs,
            max_new_tokens=50,
            bos_token_id=tokenizer_text.bos_token_id,
            eos_token_id=tokenizer_text.eos_token_id,
            pad_token_id=tokenizer_text.pad_token_id
        )
    )
    # Remove the end-of-text marker as a whole token (strip() would remove
    # matching characters, not the marker itself)
    text_pred = prediction[0].replace('<|endoftext|>', '').rstrip("\n")
    return text_pred
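
# In imageMode the generate() call receives inputs_embeds of shape
# (1, 4 + n_question_tokens, 2560): the 4 projected image pseudo-tokens
# followed by the embedded question tokens.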

def audioMode(audio):
    if audio is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    print('---------type of audio--------------')
    print(type(audio))
    print(audio)
    # Transcribe with Whisper, then answer the transcript via the text path
    text = pipe(audio, batch_size=8, generate_kwargs={"task": "transcribe"}, return_timestamps=True)["text"]
    pred_text = textMode(text, 50)

    return pred_text
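
# Audio inputs are not answered directly: the Whisper transcript is treated as
# a text question and passed to textMode with a fixed budget of 50 new tokens.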


interface_title = "Multimodal GPT Application"
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(f"## **{interface_title}**")
    gr.Markdown("Choose the input mode (text/image/audio) for text generation to chat")
    with gr.Tab("Text mode"):
        text_input = gr.Textbox(placeholder="Enter a prompt", label="Input")
        text_input_count = gr.Textbox(placeholder="Enter number of characters you want to generate", label="Count")
        text_button = gr.Button("Submit")
        text_output = gr.Textbox(label="Chat GPT like text")        
    with gr.Tab("Image mode"):
        with gr.Row():
            image_input = gr.Image(type="filepath")
            image_text_input = gr.Textbox(placeholder="Enter a question/prompt around the image", label="Question/Prompt")
        image_button = gr.Button("Submit")   
        image_text_output = gr.Textbox(label="Answer")
        
    with gr.Tab("Audio mode"):
        audio_input = gr.Audio(type="filepath")
        audio_button = gr.Button("Submit")
        audio_text_output = gr.Textbox(label="Chat GPT like text")
        

    text_button.click(textMode, inputs=[text_input, text_input_count], outputs=text_output)
    image_button.click(imageMode, inputs=[image_input, image_text_input], outputs=image_text_output)
    audio_button.click(audioMode, inputs=audio_input, outputs=audio_text_output)

    gr.Examples(
        examples=[
            ["Briefly explain the geographical features of India?","50","img69.jpg","What is the man behind the counter doing?","audio_ex3.mp3"]            
        ],
        inputs=[text_input, text_input_count, image_input, image_text_input, audio_input],
        outputs=[text_output, image_text_output, audio_text_output],
        fn=example_inference,
    )

demo.launch()