import gradio as gr
import torch
import torch.nn as nn
import whisperx
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor

clip_model_name = "openai/clip-vit-base-patch32"
phi_model_name  = "microsoft/phi-2"
tokenizer  = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
processor  = AutoProcessor.from_pretrained(clip_model_name)
tokenizer.pad_token = tokenizer.eos_token
IMAGE_TOKEN_ID = 23893  # phi-2 token id for the word "comment", used to separate image and text embeddings
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_embed = 768    # CLIP ViT-B/32 hidden size
phi_embed  = 2560   # phi-2 hidden size
compute_type = "float16"
audio_batch_size = 16

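# residual MLP block applied on top of the CLIP-to-phi-2 linear projection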
class SimpleResBlock(nn.Module):
    def __init__(self, phi_embed):
        super().__init__()
        self.pre_norm = nn.LayerNorm(phi_embed)
        self.proj = nn.Sequential(
            nn.Linear(phi_embed, phi_embed),
            nn.GELU(),
            nn.Linear(phi_embed, phi_embed)
        )
    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)
        
# models
clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
resblock = SimpleResBlock(phi_embed).to(device)
phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name, trust_remote_code=True).to(device)
audio_model = whisperx.load_model("large-v2", device, compute_type=compute_type)

# load fine-tuned weights: merge the LoRA adapter into phi-2 and restore the projection layers
model_to_merge = PeftModel.from_pretrained(phi_model, './model_chkpt/lora_adaptor')
merged_model   = model_to_merge.merge_and_unload()
projection.load_state_dict(torch.load('./model_chkpt/step2_projection.pth', map_location=torch.device(device)))
resblock.load_state_dict(torch.load('./model_chkpt/step2_resblock.pth', map_location=torch.device(device)))

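# answer an (image, question) pair: encode the image with CLIP, project it into phi-2 embedding
# space, prepend it to the question embeddings, and greedily decode up to max_generate_length tokens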
def model_generate_ans(img, val_q):

    max_generate_length = 30

    with torch.no_grad():

        # image: CLIP patch embeddings (CLS token dropped), projected into phi-2 embedding space
        image_processed  = processor(images=img, return_tensors="pt").to(device)
        clip_val_outputs = clip_model(**image_processed).last_hidden_state[:, 1:, :]
        val_image_embeds = projection(clip_val_outputs)
        val_image_embeds = resblock(val_image_embeds).to(torch.float16)
        img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
        img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)

        # audio (disabled): optional whisperx transcription of an uploaded clip
        # audio  = whisperx.load_audio(audio_file)
        # result = audio_model.transcribe(audio, batch_size=audio_batch_size)

        # audio_txt = []
        # for s in result["segments"]:
        #    audio_txt.append(s['text'])
        #    print(s['text'])

        # audio_text = "".join(audio_txt)
    
        val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
        val_q_embeds    = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)

        # [image embeddings | image token | question embeddings] -> (1, seq_len, 2560)
        val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1)

        predicted_caption = torch.full((1, max_generate_length), 50256).to(device)  # pre-filled with phi-2's EOS token id

        # greedy decoding: feed the embeddings, take the argmax token, append its embedding, repeat
        for g in range(max_generate_length):
            phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits']  # (1, seq_len, 51200)
            predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)         # (1, 1, 51200)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)       # (1, 1)
            predicted_caption[:, g] = predicted_word_token.view(1, -1)
            next_token_embeds = merged_model.model.embed_tokens(predicted_word_token)      # (1, 1, 2560)
            val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)

        predicted_captions_decoded = tokenizer.batch_decode(predicted_caption, skip_special_tokens=True)

    return predicted_captions_decoded
    

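# Gradio UI: image + question in, generated answer out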
with gr.Blocks() as demo:

    gr.Markdown(
    """
    # Chat with MultiModal GPT!
    Built by combining the CLIP vision encoder with the phi-2 language model.
    """
    )

    # app GUI
    with gr.Row():
        with gr.Column():
            img_input    = gr.Image(label='Image')
            img_question = gr.Text(label='Question')
        with gr.Column():
            img_answer   = gr.Text(label='Answer')

    section_btn = gr.Button("Submit")
    section_btn.click(model_generate_ans, inputs=[img_input, img_question], outputs=[img_answer])
    
if __name__ == "__main__":
    demo.launch()