import gradio as gr
import torch
import torch.nn as nn
import whisperx
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer, CLIPVisionModel

clip_model_name = "openai/clip-vit-base-patch32"
phi_model_name  = "microsoft/phi-2"
tokenizer  = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
processor  = AutoProcessor.from_pretrained(clip_model_name)
tokenizer.pad_token = tokenizer.eos_token
IMAGE_TOKEN_ID = 23893  # phi-2 token id for the word "comment", used as the image marker
QA_TOKEN_ID = 50295     # phi-2 token id used as the question-answer marker
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_embed = 768   # CLIP ViT-B/32 hidden size
phi_embed  = 2560  # phi-2 hidden size
compute_type = "float16" if device == "cuda" else "int8"  # ctranslate2 has no efficient float16 on CPU
audio_batch_size = 16
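
# The marker ids above can be sanity-checked with the tokenizer, e.g.
# tokenizer.decode([IMAGE_TOKEN_ID]) and tokenizer.decode([QA_TOKEN_ID]) show
# which vocabulary entries they map to.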

class SimpleResBlock(nn.Module):
    """Residual MLP block applied to the projected CLIP embeddings before they reach phi-2."""
    def __init__(self, phi_embed):
        super().__init__()
        self.pre_norm = nn.LayerNorm(phi_embed)
        self.proj = nn.Sequential(
            nn.Linear(phi_embed, phi_embed),
            nn.GELU(),
            nn.Linear(phi_embed, phi_embed)
        )
    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)
        
# models
clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
resblock = SimpleResBlock(phi_embed).to(device)
phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name, trust_remote_code=True).to(device)
audio_model = whisperx.load_model("tiny", device, compute_type=compute_type)
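
# Illustrative only (not called by the app): a quick shape check of the
# CLIP -> projection -> resblock path. With ViT-B/32, a 224x224 image yields a
# 7x7 patch grid, i.e. 49 patch embeddings once the CLS token is dropped, which
# the projector and residual block map into phi-2's 2560-dim embedding space.
def _projection_shape_check():
    dummy_patches = torch.randn(1, 49, clip_embed, device=device)
    projected = resblock(projection(dummy_patches))
    assert projected.shape == (1, 49, phi_embed)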

# load weights
model_to_merge = PeftModel.from_pretrained(phi_model, './model_chkpt/lora_adaptor')
merged_model   = model_to_merge.merge_and_unload()  # fold the LoRA adapter into the base phi-2 weights
projection.load_state_dict(torch.load('./model_chkpt/step2_projection.pth', map_location=torch.device(device)))
resblock.load_state_dict(torch.load('./model_chkpt/step2_resblock.pth', map_location=torch.device(device)))
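
# Expected checkpoint layout (inferred from the paths above):
#   ./model_chkpt/lora_adaptor/          LoRA adapter trained on top of phi-2
#   ./model_chkpt/step2_projection.pth   CLIP -> phi-2 linear projection weights
#   ./model_chkpt/step2_resblock.pth     SimpleResBlock weights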

def model_generate_ans(img=None, img_audio=None, val_q=None):

    max_generate_length = 100
    val_combined_embeds = []
    
    with torch.no_grad():
    
        # image
        if img is not None:
            image_processed  = processor(images=img, return_tensors="pt").to(device)
            clip_val_outputs = clip_model(**image_processed).last_hidden_state[:,1:,:]
            val_image_embeds = projection(clip_val_outputs)
            val_image_embeds = resblock(val_image_embeds).to(merged_model.dtype)  # match the language model's embedding dtype
            
            img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
            img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)

            val_combined_embeds.append(val_image_embeds)
            val_combined_embeds.append(img_token_embeds)

        # audio: transcribe the spoken query with whisperX, then embed the text
        if img_audio is not None:
            audio_result = audio_model.transcribe(img_audio, batch_size=audio_batch_size)
            audio_text = ''.join(seg['text'] for seg in audio_result['segments']).strip()
            audio_tokens = tokenizer(audio_text, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
            audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
            val_combined_embeds.append(audio_embeds)
            
        # text question
        if val_q:  # also guards against None, not just the empty string
            val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
            val_q_embeds    = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
            val_combined_embeds.append(val_q_embeds)


        if img_audio is not None or val_q:  # append the QA marker token after the query
            
            QA_token_tensor = torch.tensor(QA_TOKEN_ID).to(device)
            QA_token_embeds = merged_model.model.embed_tokens(QA_token_tensor).unsqueeze(0).unsqueeze(0)
            val_combined_embeds.append(QA_token_embeds)
            
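        # Final input sequence: [image patch embeds][image marker][audio transcript][typed question][QA marker]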
        if len(val_combined_embeds) == 0:
            return "Please provide an image, an audio query, or a text query."

        val_combined_embeds = torch.cat(val_combined_embeds, dim=1)
        predicted_caption = merged_model.generate(inputs_embeds=val_combined_embeds,
                                                  max_new_tokens=max_generate_length,
                                                  return_dict_in_generate=True)
    
        predicted_captions_decoded = tokenizer.batch_decode(predicted_caption.sequences[:, 1:])[0] 
        predicted_captions_decoded = predicted_captions_decoded.replace("<|endoftext|>", "")  
    
    return predicted_captions_decoded
    
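# Example usage outside the Gradio UI (illustrative; "sample.jpg" and the
# question text are placeholders, not files shipped with this repo):
#   from PIL import Image
#   print(model_generate_ans(img=Image.open("sample.jpg"), val_q="What is in the picture?"))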

with gr.Blocks() as demo:

    gr.Markdown(
    """
    # Chat with MultiModal GPT!
    Built by combining the CLIP vision model with the phi-2 language model. Ask a question about an image by typing it or recording an audio query.
    """
    )

    # app GUI
    with gr.Row():
        with gr.Column():
            img_input    = gr.Image(label='Image', type="pil")
            img_audio    = gr.Audio(label="Audio Query", sources=['microphone', 'upload'], type='filepath')
            img_question = gr.Text(label='Text Query')
        with gr.Column():
            img_answer   = gr.Text(label='Answer')

    section_btn = gr.Button("Submit")
    section_btn.click(model_generate_ans, inputs=[img_input,img_audio,img_question], outputs=[img_answer])
    
if __name__ == "__main__":
    demo.launch()
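    # demo.launch(share=True) would additionally expose a temporary public Gradio URL.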