import gc

import gradio as gr
import torch
import whisperx
from peft import PeftModel
from pydub import AudioSegment
from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor

clip_model_name = "wkcn/TinyCLIP-ViT-61M-32-Text-29M-LAION400M"
phi_model_name  = "microsoft/phi-2"
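
# Pipeline: TinyCLIP patch embeddings -> linear projection -> Phi-2 embedding space;
# spoken questions -> WhisperX transcript -> Phi-2 tokens; typed questions -> Phi-2 tokens.
# All prompt embeddings are concatenated and decoded greedily by the LoRA-finetuned, merged Phi-2.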

tokenizer  = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
processor  = AutoProcessor.from_pretrained(clip_model_name)
tokenizer.pad_token = tokenizer.eos_token
IMAGE_TOKEN_ID = 23893  # token id of the word "comment", appended after the image embeddings
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_embed = 640
phi_embed  = 2560
compute_type = "float32"
audio_batch_size = 1

# models: TinyCLIP vision encoder, projection layer, Phi-2 LLM and WhisperX ASR
clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)

projection = torch.nn.Linear(clip_embed, phi_embed).to(device)

gc.collect()
phi_model = AutoModelForCausalLM.from_pretrained(
    phi_model_name,
    trust_remote_code=True,
)
audio_model = whisperx.load_model("small", device, compute_type=compute_type)

# load weights
model_to_merge = PeftModel.from_pretrained(phi_model, './model_chkpt/')
merged_model = model_to_merge.merge_and_unload().to(device)
projection.load_state_dict(torch.load('./ft_projection.pth', map_location=torch.device(device)))

def inference(img=None, img_audio=None, val_q=None):
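    """Answer a question about an image, asked by text and/or speech.

    img:       optional PIL image from the Gradio Image component.
    img_audio: optional file path to a recorded/uploaded spoken question.
    val_q:     optional typed question string.

    The image, audio-transcript and text embeddings are concatenated in that
    order and up to `max_generate_length` tokens are generated greedily.
    """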

    max_generate_length = 50
    val_combined_embeds = []

    with torch.no_grad():

        # image
        if img is not None:
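            # TinyCLIP patch embeddings (CLS token dropped) projected into Phi-2's embedding space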
            image_processed  = processor(images=img, return_tensors="pt").to(device)
            clip_val_outputs = clip_model(**image_processed).last_hidden_state[:,1:,:]
            val_image_embeds = projection(clip_val_outputs)

            img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
            img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)

            val_combined_embeds.append(val_image_embeds)
            val_combined_embeds.append(img_token_embeds)

        # audio
        if img_audio is not None:

            # keep only the first 20 seconds of speech and transcribe it
            audio = AudioSegment.from_file(img_audio)
            clipped_audio = audio[:20 * 1000]
            clipped_audio.export('audio.mp3', format="mp3")
            result = audio_model.transcribe('audio.mp3', batch_size=audio_batch_size)
            audio_text = " ".join(seg['text'] for seg in result["segments"]).strip()
            audio_tokens = tokenizer(audio_text, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
            audio_embeds    = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
            val_combined_embeds.append(audio_embeds)

        # text question
        if val_q:
            val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
            val_q_embeds    = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
            val_combined_embeds.append(val_q_embeds)

        # nothing to do if no inputs were provided
        if not val_combined_embeds:
            return "Please provide an image, a spoken question, or a text question."

        # concatenate the image / audio / text embeddings into a single prompt
        val_combined_embeds = torch.cat(val_combined_embeds, dim=1)

        predicted_caption = torch.full((1, max_generate_length), tokenizer.eos_token_id).to(device)

        # greedy decoding: re-run the forward pass each step and append the
        # embedding of the argmax token to the prompt
        for g in range(max_generate_length):
            phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits']
            predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
            predicted_caption[:, g] = predicted_word_token.view(1, -1)
            next_token_embeds = merged_model.model.embed_tokens(predicted_word_token)
            val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)

        predicted_captions_decoded = tokenizer.batch_decode(predicted_caption, skip_special_tokens=True)[0]

    return predicted_captions_decoded

with gr.Blocks() as demo:

    gr.Markdown(
    """
    # Multi-modal LLM
    Built using the TinyCLIP vision model and Microsoft's Phi-2 language model, fine-tuned on the Instruct 150k dataset.
    """
    )

    # app GUI
    with gr.Row():
        with gr.Column():
            img_input    = gr.Image(label='Reference Image', type="pil")
            img_question = gr.Text(label='Question related to Image')
            img_audio    = gr.Audio(label="Speak a question", sources=['microphone', 'upload'], type='filepath')
        with gr.Column():
            img_answer   = gr.Text(label='Response')

    section_btn = gr.Button("Process")
    section_btn.click(inference, inputs=[img_input,img_audio,img_question], outputs=[img_answer])

if __name__ == "__main__":
    demo.launch(debug=True)