import torch
import whisperx
import gradio as gr
from peft import PeftModel
from configs import get_config_phase2
from transformers import AutoTokenizer, AutoProcessor, CLIPVisionModel, AutoModelForCausalLM
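
# Gradio demo for a text / audio / image -> text assistant:
#   - CLIP encodes the input image into patch embeddings,
#   - a learned linear projection maps those embeddings into phi-2's embedding space,
#   - WhisperX transcribes the spoken question, and
#   - a QLoRA-finetuned phi-2 generates the answer token by token.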

config = get_config_phase2()

# vision encoder: CLIP provides the patch embeddings that are later projected into phi-2's embedding space
clip_model = CLIPVisionModel.from_pretrained(config.get("clip_model_name")).to(config.get("device"))

base_model = AutoModelForCausalLM.from_pretrained(
    config.get("phi2_model_name"),
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    trust_remote_code=True
)
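# NOTE: the base phi-2 weights are loaded in float16, so the projected image embeddings
# are cast to float16 inside generate_answers before being concatenated with token embeddings.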


# attach the QLoRA adapter to the base phi-2 model and merge the weights for inference
ckpts = "ckpts/Qlora_adaptor/"
phi2_model = PeftModel.from_pretrained(base_model, ckpts)
phi2_model = phi2_model.merge_and_unload().to(config.get("device"))

# linear projection that maps CLIP patch embeddings into phi-2's embedding space
projection_layer = torch.nn.Linear(config.get("clip_embed"), config.get("phi_embed")).to(config.get("device"))
projection_layer.load_state_dict(torch.load('./ckpts/model_phase2.pth', map_location=config.get("device")))

# phi-2 tokenizer and CLIP image processor
tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
processor = AutoProcessor.from_pretrained(config.get("clip_model_name"), trust_remote_code=True)

# WhisperX ASR model for transcribing the spoken question; 'tiny' on CPU keeps the demo lightweight
audio_model = whisperx.load_model('tiny', 'cpu', compute_type="float32")


def generate_answers(img=None, aud=None, q=None, max_tokens=30):
    batch_size = 1
    max_tokens = int(max_tokens)  # the Gradio slider may pass a float

    # <iQ> ... </iQ> delimiters wrap the multimodal context (image / audio / text question)
    start_iq = tokenizer.encode("<iQ>")
    end_iq = tokenizer.encode("</iQ>")
    start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
    end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
    start_iq_embeds = phi2_model.model.embed_tokens(start_iq_embeds.to(config.get("device")))
    end_iq_embeds = phi2_model.model.embed_tokens(end_iq_embeds.to(config.get("device")))

    inputs_embeddings = []
    inputs_embeddings.append(start_iq_embeds)

    # output buffer, pre-filled with phi-2's <|endoftext|> token id (50256)
    predicted_caption = torch.full((batch_size, max_tokens), 50256, dtype=torch.long, device=config.get("device"))
    
    if img is not None:
        # CLIP image features: drop the CLS token, keep the patch embeddings,
        # then project them into phi-2's embedding space
        images = processor(images=img, return_tensors="pt")['pixel_values'].to(config.get("device"))
        clip_outputs = clip_model(pixel_values=images)
        patch_embeds = clip_outputs.last_hidden_state[:, 1:, :]  # remove CLS token
        image_embeddings = projection_layer(patch_embeds).to(torch.float16)
        inputs_embeddings.append(image_embeddings)
    
    if aud is not None:
        # transcribe the spoken question with WhisperX and embed it with phi-2's token embeddings
        trans = audio_model.transcribe(aud)
        audio_res = "".join(seg['text'] for seg in trans['segments']).strip()
        audio_tokens = tokenizer(audio_res, return_tensors="pt", return_attention_mask=False)['input_ids']
        audio_embeds = phi2_model.model.embed_tokens(audio_tokens.to(config.get("device")))
        inputs_embeddings.append(audio_embeds)
        
    if q:
        # embed the typed question (q can be None or empty when only image/audio are provided)
        ques = tokenizer(q, return_tensors="pt", return_attention_mask=False)['input_ids']
        q_embeds = phi2_model.model.embed_tokens(ques.to(config.get("device")))
        inputs_embeddings.append(q_embeds)
        
    inputs_embeddings.append(end_iq_embeds)
    # concatenate <iQ>, image, audio and question embeddings along the sequence dimension
    combined_embeds = torch.cat(inputs_embeddings, dim=1)

    # greedy decoding: append the argmax token's embedding and re-run the model each step
    with torch.no_grad():
        for pos in range(max_tokens):
            model_output_logits = phi2_model(inputs_embeds=combined_embeds)['logits']
            predicted_word_token = torch.argmax(model_output_logits[:, -1, :], dim=-1)  # shape (batch_size,)
            predicted_caption[:, pos] = predicted_word_token
            if predicted_word_token.item() == 50256:  # stop once <|endoftext|> is produced
                break
            next_token_embeds = phi2_model.model.embed_tokens(predicted_word_token.unsqueeze(1))
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)

    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption)[0]
    predicted_captions_decoded = predicted_captions_decoded.replace("<|endoftext|>", "")
    return predicted_captions_decoded.strip()
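
# Example of calling generate_answers directly (paths and question are illustrative only):
#   from PIL import Image
#   print(generate_answers(img=Image.open("sample.jpg"),
#                          aud="question.wav",
#                          q="What is shown in the image?",
#                          max_tokens=30))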


with gr.Blocks() as demo:

    gr.Markdown(
    """
    # TAI2T Model (Text, Audio, Image to Text)
    A multimodal GPT that takes image, audio, and text inputs and returns a text answer.
    """
    )

    with gr.Row():
        with gr.Column():
            image = gr.Image(label='Image', type="pil", value=None)
            audio_q = gr.Audio(label="Audio Question", value=None, sources=['microphone', 'upload'], type='filepath')
            question = gr.Text(label='Question?', value=None)
            max_tokens = gr.Slider(1, 50, value=10, step=1, label="Max tokens")
    with gr.Row():
        answer = gr.Text(label='Answer')
    with gr.Row():
        submit = gr.Button("Submit")
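        # wire the button: run generate_answers on the current inputs and show the result in `answer`;
        # the ClearButton below resets all inputs and the answer box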
        submit.click(generate_answers, inputs=[image, audio_q, question, max_tokens], outputs=[answer])
        clear_btn = gr.ClearButton([image, audio_q, question, max_tokens, answer])
    
if __name__ == "__main__":
    # share=True asks Gradio for a temporary public link; debug=True surfaces errors in the terminal
    demo.launch(share=True, debug=True)