import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
from torchvision import transforms
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    CLIPModel,
    CLIPProcessor,
    pipeline,
)


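# Projection head that maps a CLIP image embedding into phi-2's token-embedding
# space. Each of the `width` parallel MLPs produces one pseudo-token embedding;
# forward() concatenates them along the sequence dimension (dim=-2).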
class _MLPVectorProjector(nn.Module):
    def __init__(
        self, input_hidden_size: int, lm_hidden_size: int, num_layers: int, width: int
    ):
        super().__init__()
        self.mlps = nn.ModuleList()
        for _ in range(width):
            mlp = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
            for _ in range(1, num_layers):
                mlp.append(nn.GELU())
                mlp.append(nn.Linear(lm_hidden_size, lm_hidden_size, bias=False))
            self.mlps.append(nn.Sequential(*mlp))

    def forward(self, x):
        return torch.cat([mlp(x) for mlp in self.mlps], dim=-2)

## Text model
# Base phi-2 checkpoint: supplies the tokenizer; generation below goes through
# the fine-tuned model loaded further down.
model_name = "microsoft/phi-2"

with torch.no_grad():
    phi2_text = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True, device_map="auto", torch_dtype=torch.float16
    )

tokenizer_text = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

## Audio model
model_name_audio = "openai/whisper-small"
pipe = pipeline(
    task="automatic-speech-recognition",
    model=model_name_audio,
    chunk_length_s=30,
    device="cpu",
)

## Image model
# CLIP vision encoder
model_id_clip = "openai/clip-vit-base-patch16"
model_clip = CLIPModel.from_pretrained(model_id_clip).to("cpu")
processor_clip = CLIPProcessor.from_pretrained(model_id_clip)

print('--------------Loaded CLIP----------------------')

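# Image pipeline: PIL image -> CLIP image embedding (512-d) -> projection head
# -> pseudo-token embeddings that are prepended to the question embeddings for phi-2.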
# Preprocess the image for clip
def preprocess_image(image_path):
    image = Image.open(image_path).convert("RGB")
    image = transforms.Resize((224, 224))(image)
    image = transforms.ToTensor()(image)
    return image.unsqueeze(0)

# Get the CLIP image embedding
def encode_image(image_path):
    image = preprocess_image(image_path).to("cpu")
    # CLIPModel's forward pass also expects text, so pass an empty dummy string
    dummy_text = ""
    inputs = processor_clip(text=dummy_text, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model_clip(**inputs)
    img_embedding = outputs.image_embeds
    return img_embedding

# Projection head: maps the 512-d CLIP embedding to 4 pseudo-tokens in phi-2's 2560-d embedding space
img_proj_head = _MLPVectorProjector(512, 2560, 1, 4).to("cpu")
img_proj_head.load_state_dict(torch.load('projection_finetuned.pth', map_location=torch.device('cpu')))
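# Note: 'projection_finetuned.pth' is loaded from the current working directory.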

print('--------------Loaded proj head----------------------')

# Get the fine-tuned phi-2 model
with torch.no_grad():
    phi2_finetuned = AutoModelForCausalLM.from_pretrained(
        "phi2_adaptor_fineTuned", trust_remote_code=True).to("cpu")

print('--------------Loaded fine tuned phi2 model----------------------')
    

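# Runs all three modes on a single example row (used by gr.Examples below)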
def example_inference(input_text, count, image, img_qn, audio):
    pred_text = textMode(input_text, count)
    pred_text_image = imageMode(image, img_qn)
    pred_text_audio = audioMode(audio)
    return pred_text, pred_text_image, pred_text_audio



# Text mode: answer a free-form question with the fine-tuned phi-2 model
def textMode(text, count):
    count = int(count)
    text = "Question: " + text + "Answer: "
    inputs = tokenizer_text(text, return_tensors="pt", return_attention_mask=False)
    prediction = tokenizer_text.batch_decode(
        phi2_finetuned.generate(
            **inputs,
            max_new_tokens=count,
            bos_token_id=tokenizer_text.bos_token_id,
            eos_token_id=tokenizer_text.eos_token_id,
            pad_token_id=tokenizer_text.pad_token_id,
        )
    )
    # rstrip('<|endoftext|>') would strip a *set* of characters rather than the
    # literal token, so remove the exact suffix instead.
    return prediction[0].removesuffix('<|endoftext|>').rstrip("\n")


# Image mode: prepend the projected image pseudo-tokens to the question embeddings
def imageMode(image, question):
    image_embedding = encode_image(image)
    print('-------Image embedding from clip obtained-----------')
    # Project the CLIP embedding into phi-2's embedding space: shape (1, 4, 2560)
    imgToTextEmb = img_proj_head(image_embedding).unsqueeze(0)
    print('-------text embedding from projection obtained-----------')
    question = "Question: " + question + "Answer: "
    Qtokens = torch.tensor(tokenizer_text.encode(question, add_special_tokens=True)).unsqueeze(0)
    Qtoken_embeddings = phi2_finetuned.get_submodule('model.embed_tokens')(Qtokens)
    print('-------question embedding from phi2 obtained-----------')
    inputs = torch.cat((imgToTextEmb, Qtoken_embeddings), dim=-2)

    prediction = tokenizer_text.batch_decode(
        phi2_finetuned.generate(
            inputs_embeds=inputs,
            max_new_tokens=50,
            bos_token_id=tokenizer_text.bos_token_id,
            eos_token_id=tokenizer_text.eos_token_id,
            pad_token_id=tokenizer_text.pad_token_id,
        )
    )
    return prediction[0].removesuffix('<|endoftext|>').rstrip("\n")

# Audio mode: transcribe with Whisper, then answer the transcription in text mode
def audioMode(audio):
    if audio is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    text = pipe(audio, batch_size=8, generate_kwargs={"task": "transcribe"}, return_timestamps=True)["text"]
    pred_text = textMode(text, 50)
    return pred_text


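# Gradio UI: one tab per mode (text / image / audio), plus a combined example row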
interface_title = "TSAI-ERA-V1 - Capstone - Multimodal GPT Demo"
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(f"## **{interface_title}**")
    gr.Markdown("Choose text mode/image mode/audio mode for generation")
    with gr.Tab("Text mode"):
        text_input = gr.Textbox(placeholder="Enter a prompt", label="Input")
        text_input_count = gr.Textbox(placeholder="Enter the number of tokens to generate", label="Count")
        text_button = gr.Button("Submit")
        text_output = gr.Textbox(label="Chat GPT like text")        
    with gr.Tab("Image mode"):
        with gr.Row():
            image_input = gr.Image(type="filepath")
            image_text_input = gr.Textbox(placeholder="Enter a question/prompt around the image", label="Question/Prompt")
        image_button = gr.Button("Submit")   
        image_text_output = gr.Textbox(label="Answer")
        
    with gr.Tab("Audio mode"):
        audio_input = gr.Audio(type="filepath")
        audio_button = gr.Button("Submit")
        audio_text_output = gr.Textbox(label="Chat GPT like text")
        

    text_button.click(textMode, inputs=[text_input, text_input_count], outputs=text_output)
    image_button.click(imageMode, inputs=[image_input, image_text_input], outputs=image_text_output)
    audio_button.click(audioMode, inputs=audio_input, outputs=audio_text_output)

    gr.Examples(
        examples=[
            ["What is a large language model?","50","zebras.png","Are the zebras walking or standing still in the image?","WtIsML.m4a"]            
        ],
        inputs=[text_input, text_input_count, image_input, image_text_input, audio_input],
        outputs=[text_output, image_text_output, audio_text_output],
        fn=example_inference,
    )

demo.launch()