vigraj committed · Commit bcb5222 · verified · 1 parent: 4c4a7e6

Upload 2 files

Files changed (2):
  1. app.py +124 -0
  2. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,124 @@
import gradio as gr
import torch
import torch.nn as nn
import whisperx
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor

# model names and basic settings
clip_model_name = "wkcn/TinyCLIP-ViT-61M-32-Text-29M-LAION400M"
phi_model_name = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(clip_model_name)
tokenizer.pad_token = tokenizer.eos_token
IMAGE_TOKEN_ID = 23893  # token id for the word "comment", used to mark the image embeddings
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_embed = 640   # TinyCLIP vision hidden size
phi_embed = 2560   # phi-2 hidden size
compute_type = "float32"
audio_batch_size = 16

class SimpleResBlock(nn.Module):
    """Residual MLP block applied on top of the CLIP-to-phi projection."""
    def __init__(self, phi_embed):
        super().__init__()
        self.pre_norm = nn.LayerNorm(phi_embed)
        self.proj = nn.Sequential(
            nn.Linear(phi_embed, phi_embed),
            nn.GELU(),
            nn.Linear(phi_embed, phi_embed)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)

# models
clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
resblock = SimpleResBlock(phi_embed).to(device)
phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name, trust_remote_code=True).to(device)
audio_model = whisperx.load_model("tiny", device, compute_type=compute_type)

# load fine-tuned weights: QLoRA adapter for phi-2, plus the projection layer and residual block
model_to_merge = PeftModel.from_pretrained(phi_model, './model_chkpt/qlora_adaptor')
merged_model = model_to_merge.merge_and_unload()
projection.load_state_dict(torch.load('./model_chkpt/ft_projection_layer.pth', map_location=torch.device(device)))
resblock.load_state_dict(torch.load('./model_chkpt/ft_projection_model.pth', map_location=torch.device(device)))

def model_generate_ans(img=None, img_audio=None, val_q=None):

    max_generate_length = 100
    val_combined_embeds = []

    with torch.no_grad():

        # image: CLIP patch embeddings -> projection -> residual block, followed by the image marker token
        if img is not None:
            image_processed = processor(images=img, return_tensors="pt").to(device)
            clip_val_outputs = clip_model(**image_processed).last_hidden_state[:, 1:, :]  # drop the CLS token
            val_image_embeds = projection(clip_val_outputs)
            # keep the same dtype as the phi-2 token embeddings so the concatenation below is consistent
            val_image_embeds = resblock(val_image_embeds).to(merged_model.model.embed_tokens.weight.dtype)

            img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
            img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)

            val_combined_embeds.append(val_image_embeds)
            val_combined_embeds.append(img_token_embeds)

        # audio: transcribe with whisperx, then embed the transcript tokens
        if img_audio is not None:
            audio_result = audio_model.transcribe(img_audio, batch_size=audio_batch_size)
            audio_text = ''
            for seg in audio_result['segments']:
                audio_text += seg['text']
            audio_text = audio_text.strip()
            audio_tokens = tokenizer(audio_text, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
            audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
            val_combined_embeds.append(audio_embeds)

        # text question
        if val_q:
            val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
            val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
            val_combined_embeds.append(val_q_embeds)

        val_combined_embeds = torch.cat(val_combined_embeds, dim=1)  # (1, seq_len, 2560)

        # greedy decoding: append the embedding of each predicted token and run the model again
        predicted_caption = torch.full((1, max_generate_length), 50256).to(device)  # pre-filled with the eos token id

        for g in range(max_generate_length):
            phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits']  # (1, seq_len, 51200)
            predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)         # (1, 1, 51200)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)       # (1, 1)
            predicted_caption[:, g] = predicted_word_token.view(1, -1)
            next_token_embeds = phi_model.model.embed_tokens(predicted_word_token)         # (1, 1, 2560)
            val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)

        predicted_captions_decoded = tokenizer.batch_decode(predicted_caption, skip_special_tokens=True)[0]

    return predicted_captions_decoded


with gr.Blocks() as demo:

    gr.Markdown(
        """
        # Chat with MultiModal GPT!
        Built by combining a TinyCLIP vision encoder with the phi-2 language model.
        """
    )

    # app GUI
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label='Image', type="pil")
            img_audio = gr.Audio(label="Audio Query", sources=['microphone', 'upload'], type='filepath')
            img_question = gr.Text(label='Text Query')
        with gr.Column():
            img_answer = gr.Text(label='Answer')

    section_btn = gr.Button("Submit")
    section_btn.click(model_generate_ans, inputs=[img_input, img_audio, img_question], outputs=[img_answer])

demo.launch()
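
For a quick check outside the Gradio UI, the inference helper can also be called directly. The following is a minimal sketch, assuming the ./model_chkpt weights above are in place; sample.jpg and the question string are placeholder inputs, not part of the original app.

    # Minimal sketch: call the inference function directly (placeholder inputs).
    from PIL import Image

    image = Image.open("sample.jpg")          # hypothetical local image
    question = "What is shown in this picture?"

    answer = model_generate_ans(img=image, img_audio=None, val_q=question)
    print(answer)
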
requirements.txt ADDED
@@ -0,0 +1,11 @@
torch
peft
accelerate
transformers==4.37
einops
git+https://github.com/m-bain/whisperx.git
bitsandbytes
wandb
ffmpeg
pydub
gradio
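
A small environment sanity check, assuming the requirements above were installed with pip: it imports the listed Python packages and prints their versions. The ffmpeg entry is skipped here since ffmpeg is typically provided as a system binary rather than an importable module.

    # Minimal sketch: verify the pinned requirements import correctly.
    import importlib

    packages = ["torch", "peft", "accelerate", "transformers", "einops",
                "whisperx", "bitsandbytes", "wandb", "pydub", "gradio"]

    for name in packages:
        module = importlib.import_module(name)
        print(name, getattr(module, "__version__", "version not exposed"))
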