Vasudevakrishna committed on
Commit
7e28176
1 Parent(s): 7514cc7
Files changed (1)
  1. app.py +113 -0
app.py ADDED
@@ -0,0 +1,113 @@
+ import torch
+ import whisperx
+ import gradio as gr
+ from peft import PeftModel
+ from configs import get_config_phase2
+ from transformers import AutoTokenizer, AutoProcessor, CLIPVisionModel, AutoModelForCausalLM
+
+ config = get_config_phase2()
+
+ # CLIP vision encoder; keep it on the same device as the rest of the pipeline
+ clip_model = CLIPVisionModel.from_pretrained(config.get("clip_model_name")).to(config.get("device"))
+
+ base_model = AutoModelForCausalLM.from_pretrained(
+     config.get("phi2_model_name"),
+     low_cpu_mem_usage=True,
+     return_dict=True,
+     torch_dtype=torch.float32,
+     trust_remote_code=True
+ )
+
+ # merge the QLoRA adapter into the base phi-2 model
+ ckpts = "ckpts/Qlora_adaptor/"
+ phi2_model = PeftModel.from_pretrained(base_model, ckpts)
+ phi2_model = phi2_model.merge_and_unload().to(config.get("device"))
+
+ # projection from CLIP embedding space into phi-2 embedding space
+ projection_layer = torch.nn.Linear(config.get("clip_embed"), config.get("phi_embed"))
+ projection_layer.load_state_dict(torch.load('./ckpts/model_phase2.pth', map_location=config.get("device")))
+ projection_layer = projection_layer.to(config.get("device"))
+
+ # tokenizer and image processor
+ tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
+ processor = AutoProcessor.from_pretrained(config.get("clip_model_name"), trust_remote_code=True)
+
+ # whisperx model used to transcribe spoken questions
+ audio_model = whisperx.load_model('tiny', 'cpu', compute_type="float32")
+
+
+ @torch.no_grad()  # inference only
+ def generate_answers(img=None, aud=None, q=None, max_tokens=30):
+     batch_size = 1
+     max_tokens = int(max_tokens)  # the Gradio slider may pass the value as a float
+
+     # embed the <iQ> ... </iQ> delimiter tokens used during training
+     start_iq = tokenizer.encode("<iQ>")
+     end_iq = tokenizer.encode("</iQ>")
+     start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
+     end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
+     start_iq_embeds = phi2_model.model.embed_tokens(start_iq_embeds.to(config.get("device")))
+     end_iq_embeds = phi2_model.model.embed_tokens(end_iq_embeds.to(config.get("device")))
+
+     inputs_embeddings = []
+     inputs_embeddings.append(start_iq_embeds)
+
+     # 50256 is phi-2's <|endoftext|> token, used here as padding
+     predicted_caption = torch.full((batch_size, max_tokens), 50256, dtype=torch.long, device=config.get('device'))
+
+     if img is not None:
+         images = processor(images=img, return_tensors="pt")['pixel_values'].to(config.get('device'))
+         clip_outputs = clip_model(pixel_values=images)
+         # remove cls token, keep only the patch embeddings
+         images = clip_outputs.last_hidden_state[:, 1:, :]
+         image_embeddings = projection_layer(images).to(torch.float32)
+         inputs_embeddings.append(image_embeddings)
+
+     if aud is not None:
+         # transcribe the spoken question and embed it like ordinary text
+         trans = audio_model.transcribe(aud)
+         audio_res = ""
+         for seg in trans['segments']:
+             audio_res += seg['text']
+         audio_res = audio_res.strip()
+         audio_tokens = tokenizer(audio_res, return_tensors="pt", return_attention_mask=False)['input_ids']
+         audio_embeds = phi2_model.model.embed_tokens(audio_tokens.to(config.get("device")))
+         inputs_embeddings.append(audio_embeds)
+
+     if q is not None:
+         ques = tokenizer(q, return_tensors="pt", return_attention_mask=False)['input_ids']
+         q_embeds = phi2_model.model.embed_tokens(ques.to(config.get("device")))
+         inputs_embeddings.append(q_embeds)
+
+     inputs_embeddings.append(end_iq_embeds)
+     # Combine embeddings
+     combined_embeds = torch.cat(inputs_embeddings, dim=1)
+
+     # greedy decoding: append the argmax token's embedding and feed it back in
+     for pos in range(max_tokens - 1):
+         model_output_logits = phi2_model(inputs_embeds=combined_embeds)['logits']
+         predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
+         predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
+         predicted_caption[:, pos] = predicted_word_token.view(-1)
+         if predicted_word_token.item() == 50256:  # stop once <|endoftext|> is produced
+             break
+         next_token_embeds = phi2_model.model.embed_tokens(predicted_word_token)
+         combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
+
+     # batch_decode has no ignore_index argument; skip_special_tokens drops the padding
+     predicted_captions_decoded = tokenizer.batch_decode(predicted_caption, skip_special_tokens=True)[0]
+     predicted_captions_decoded = predicted_captions_decoded.replace("<|endoftext|>", "")
+     return predicted_captions_decoded
+
+
+ with gr.Blocks() as demo:
+
+     gr.Markdown(
+         """
+         # TAI2T Model (Text, Audio, Image to Text Model)
+         Multimodal GPT that takes image, audio and text as input and produces text as output.
+         """
+     )
+
+     with gr.Row():
+         with gr.Column():
+             image = gr.Image(label='Image', type="pil", value=None)
+             audio_q = gr.Audio(label="Audio Question", value=None, sources=['microphone', 'upload'], type='filepath')
+             question = gr.Text(label='Question?', value=None)
+             max_tokens = gr.Slider(1, 50, value=10, step=1, label="Max tokens")
+     with gr.Row():
+         answer = gr.Text(label='Answer')
+     with gr.Row():
+         submit = gr.Button("Submit")
+         submit.click(generate_answers, inputs=[image, audio_q, question, max_tokens], outputs=[answer])
+         clear_btn = gr.ClearButton([image, audio_q, question, max_tokens, answer])
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
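
For context, app.py imports get_config_phase2 from a configs module that is not part of this commit. Judging only from the keys the script reads, a minimal sketch of that function could look like the one below; the concrete checkpoint names and embedding sizes are assumptions, not values taken from this repo.

# configs.py -- hypothetical sketch, NOT included in this commit.
# It only mirrors the keys that app.py accesses; the actual values may differ.
import torch

def get_config_phase2():
    return {
        "clip_model_name": "openai/clip-vit-base-patch32",  # assumed CLIP checkpoint
        "phi2_model_name": "microsoft/phi-2",                # assumed phi-2 checkpoint
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "clip_embed": 768,    # must match the CLIP hidden size
        "phi_embed": 2560,    # must match the phi-2 hidden size
    }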
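
The Gradio UI is only a thin wrapper around generate_answers, so the function can also be exercised directly, for example with an image plus a typed question; the file name below is a placeholder.

# Hypothetical direct call, bypassing the Gradio interface; "example.jpg" is a placeholder path.
from PIL import Image

img = Image.open("example.jpg")
print(generate_answers(img=img, aud=None, q="What is shown in the image?", max_tokens=30))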