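"""
Gradio demo for a multimodal (text / audio / image -> text) assistant.

Pipeline: a CLIP vision encoder produces image patch embeddings, a learned
linear projection maps them into Phi-2's embedding space, WhisperX transcribes
spoken questions, and a QLoRA-finetuned Phi-2 generates the text answer.
"""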
import torch
import whisperx
import gradio as gr
from peft import PeftModel
from configs import get_config_phase2
from transformers import AutoTokenizer, AutoProcessor, CLIPVisionModel, AutoModelForCausalLM

config = get_config_phase2()

# CLIP vision encoder used to extract patch embeddings from input images
clip_model = CLIPVisionModel.from_pretrained(config.get("clip_model_name"))

base_model = AutoModelForCausalLM.from_pretrained(
    config.get("phi2_model_name"),
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float32,
    trust_remote_code=True
)


ckpts = "ckpts/Qlora_adaptor/"
phi2_model = PeftModel.from_pretrained(base_model, ckpts)
phi2_model = phi2_model.merge_and_unload().to(config.get("device"))

projection_layer = torch.nn.Linear(config.get("clip_embed"), config.get("phi_embed"))
projection_layer.load_state_dict(torch.load('./ckpts/model_phase2.pth', map_location=config.get("device")))

# Tokenizer for Phi-2 and image processor for CLIP
tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
processor = AutoProcessor.from_pretrained(config.get("clip_model_name"), trust_remote_code=True)

# WhisperX model used to transcribe spoken questions into text
audio_model = whisperx.load_model('tiny', 'cpu', compute_type="float32")
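
# The multimodal prompt is assembled entirely in embedding space:
#   <iQ> [projected CLIP patch embeddings] [embedded audio transcript] [embedded text question] </iQ>
# generate_answers() below builds this sequence and lets Phi-2 complete it.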


def generate_answers(img=None, aud=None, q=None, max_tokens=30):
    batch_size = 1

    # Embed the <iQ> ... </iQ> markers that delimit the multimodal prompt
    start_iq = tokenizer.encode("<iQ>")
    end_iq = tokenizer.encode("</iQ>")
    start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
    end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
    start_iq_embeds = phi2_model.model.embed_tokens(start_iq_embeds.to(config.get("device")))
    end_iq_embeds = phi2_model.model.embed_tokens(end_iq_embeds.to(config.get("device")))

    inputs_embeddings = [start_iq_embeds]
    
    if img is not None:
        # Encode the image with CLIP and project its patch embeddings into Phi-2's space
        pixel_values = processor(images=img, return_tensors="pt")['pixel_values'].to(config.get("device"))
        clip_outputs = clip_model(pixel_values=pixel_values)
        # Drop the CLS token; keep only the patch embeddings
        patch_embeds = clip_outputs.last_hidden_state[:, 1:, :]
        image_embeddings = projection_layer(patch_embeds).to(torch.float32)
        inputs_embeddings.append(image_embeddings)
    
    if aud is not None:
        # Transcribe the spoken question with WhisperX and embed the resulting text
        trans = audio_model.transcribe(aud)
        audio_res = "".join(seg['text'] for seg in trans['segments']).strip()
        audio_tokens = tokenizer(audio_res, return_tensors="pt", return_attention_mask=False)['input_ids']
        audio_embeds = phi2_model.model.embed_tokens(audio_tokens.to(config.get("device")))
        inputs_embeddings.append(audio_embeds)
        
    if q:
        # Embed the typed question, if one was provided
        ques = tokenizer(q, return_tensors="pt", return_attention_mask=False)['input_ids']
        q_embeds = phi2_model.model.embed_tokens(ques.to(config.get("device")))
        inputs_embeddings.append(q_embeds)
        
    inputs_embeddings.append(end_iq_embeds)

    # Concatenate all modality embeddings along the sequence dimension
    combined_embeds = torch.cat(inputs_embeddings, dim=1)

    generated = phi2_model.generate(inputs_embeds=combined_embeds,
                                    max_new_tokens=max_tokens,
                                    return_dict_in_generate=True)

    # Decode the generated token ids, dropping the first token and any <|endoftext|> padding
    predicted_caption = tokenizer.batch_decode(generated.sequences[:, 1:])[0]
    return predicted_caption.replace("<|endoftext|>", "")
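
# Illustrative example (not executed): calling generate_answers directly, bypassing the UI.
# "example.jpg" is a placeholder path; any image that PIL can open should work.
#
#   from PIL import Image
#   img = Image.open("example.jpg")
#   print(generate_answers(img=img, aud=None, q="What is shown in the image?", max_tokens=30))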


with gr.Blocks() as demo:

    gr.Markdown(
    """
    # TAI2T Model (Text, Audio, Image to Text)
    A multimodal GPT that takes an image, an audio question, and/or a text question as input and returns a text answer.
    """
    )

    with gr.Row():
        with gr.Column():
            image = gr.Image(label='Image', type="pil", value=None)
            audio_q = gr.Audio(label="Audio Question", value=None, sources=['microphone', 'upload'], type='filepath')
            question = gr.Text(label='Question', value=None)
            max_tokens = gr.Slider(1, 50, value=10, step=1, label="Max tokens")
    with gr.Row():
        answer = gr.Text(label='Answer')
    with gr.Row():
        submit = gr.Button("Submit")
        submit.click(generate_answers, inputs=[image, audio_q, question, max_tokens], outputs=[answer])
        clear_btn = gr.ClearButton([image, audio_q, question, max_tokens, answer])
    
if __name__ == "__main__":
    demo.launch(share=True, debug=True)