wgetdd committed
Commit a9e35af • 1 Parent(s): 706433a

Upload 2 files

Files changed (2)
  1. app.py +126 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,126 @@
+ import gradio as gr
+ import peft
+ from peft import LoraConfig
+ from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor
+ import torch
+ from peft import PeftModel
+ import torch.nn as nn
+ import whisperx
+
+ clip_model_name = "openai/clip-vit-base-patch32"
+ phi_model_name = "microsoft/phi-2"
+ tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
+ processor = AutoProcessor.from_pretrained(clip_model_name)
+ tokenizer.pad_token = tokenizer.eos_token
+ IMAGE_TOKEN_ID = 23893  # token id of the word "comment", appended after the image embeddings
+ QA_TOKEN_ID = 50295  # token id of "qa", appended right before generation
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ clip_embed = 768   # CLIP vision hidden size
+ phi_embed = 2560   # phi-2 hidden size
+ compute_type = "float16"
+ audio_batch_size = 16
+
+ class SimpleResBlock(nn.Module):
+     def __init__(self, phi_embed):
+         super().__init__()
+         self.pre_norm = nn.LayerNorm(phi_embed)
+         self.proj = nn.Sequential(
+             nn.Linear(phi_embed, phi_embed),
+             nn.GELU(),
+             nn.Linear(phi_embed, phi_embed)
+         )
+     def forward(self, x):
+         x = self.pre_norm(x)
+         return x + self.proj(x)
+
+ # models
+ clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
+ projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
+ resblock = SimpleResBlock(phi_embed).to(device)
+ phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name, trust_remote_code=True).to(device)
+ audio_model = whisperx.load_model("tiny", device, compute_type=compute_type)
+
+ # load weights
+ model_to_merge = PeftModel.from_pretrained(phi_model, './model_chkpt/lora_adaptor')
+ merged_model = model_to_merge.merge_and_unload()
+ projection.load_state_dict(torch.load('./model_chkpt/step2_projection.pth', map_location=torch.device(device)))
+ resblock.load_state_dict(torch.load('./model_chkpt/step2_resblock.pth', map_location=torch.device(device)))
+
+ def model_generate_ans(img=None, img_audio=None, val_q=None):
+
+     max_generate_length = 100
+     val_combined_embeds = []
+
+     with torch.no_grad():
+
+         # image
+         if img is not None:
+             image_processed = processor(images=img, return_tensors="pt").to(device)
+             clip_val_outputs = clip_model(**image_processed).last_hidden_state[:, 1:, :]
+             val_image_embeds = projection(clip_val_outputs)
+             val_image_embeds = resblock(val_image_embeds).to(torch.float16)
+
+             img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
+             img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
+
+             val_combined_embeds.append(val_image_embeds)
+             val_combined_embeds.append(img_token_embeds)
+
+         # audio
+         if img_audio is not None:
+             audio_result = audio_model.transcribe(img_audio)
+             audio_text = ''
+             for seg in audio_result['segments']:
+                 audio_text += seg['text']
+             audio_text = audio_text.strip()
+             audio_tokens = tokenizer(audio_text, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
+             audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
+             val_combined_embeds.append(audio_embeds)
+
+         # text question
+         if len(val_q) != 0:
+             val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
+             val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
+             val_combined_embeds.append(val_q_embeds)
+
+
+         if img_audio is not None or len(val_q) != 0:  # add QA token
+
+             QA_token_tensor = torch.tensor(QA_TOKEN_ID).to(device)
+             QA_token_embeds = merged_model.model.embed_tokens(QA_token_tensor).unsqueeze(0).unsqueeze(0)
+             val_combined_embeds.append(QA_token_embeds)
+
+         val_combined_embeds = torch.cat(val_combined_embeds, dim=1)
+         predicted_caption = merged_model.generate(inputs_embeds=val_combined_embeds,
+                                                   max_new_tokens=max_generate_length,
+                                                   return_dict_in_generate=True)
+
+         predicted_captions_decoded = tokenizer.batch_decode(predicted_caption.sequences[:, 1:])[0]
+         predicted_captions_decoded = predicted_captions_decoded.replace("<|endoftext|>", "")
+
+     return predicted_captions_decoded
+
+
+ with gr.Blocks() as demo:
+
+     gr.Markdown(
+         """
+         # Chat with MultiModal GPT!
+         Built by combining the CLIP vision model with the phi-2 language model.
+         """
+     )
+
+     # app GUI
+     with gr.Row():
+         with gr.Column():
+             img_input = gr.Image(label='Image', type="pil")
+             img_audio = gr.Audio(label="Audio Query", sources=['microphone', 'upload'], type='filepath')
+             img_question = gr.Text(label='Text Query')
+         with gr.Column():
+             img_answer = gr.Text(label='Answer')
+
+     section_btn = gr.Button("Submit")
+     section_btn.click(model_generate_ans, inputs=[img_input, img_audio, img_question], outputs=[img_answer])
+
+ if __name__ == "__main__":
+     demo.launch()
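
For reference, model_generate_ans can also be called directly, without going through the Gradio UI, once the checkpoints under ./model_chkpt are in place. A minimal sketch, where the image file name and the question are placeholders:

from PIL import Image

image = Image.open("sample.jpg")   # hypothetical local image file
answer = model_generate_ans(img=image, img_audio=None, val_q="What is shown in this image?")
print(answer)
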
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ peft
+ accelerate
+ transformers
+ einops
+ git+https://github.com/m-bain/whisperx.git
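
Note that gradio itself is not listed here, presumably because the app runs as a Hugging Face Space, where gradio is installed from the Space's SDK setting. For a local run it would need to be installed separately (for example, pip install gradio), followed by pip install -r requirements.txt and python app.py, assuming the ./model_chkpt weights are present.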