VarunSivamani committed
Commit a3c3623
1 Parent(s): 8e69ded

application file

Files changed (1)
  1. app.py +156 -0
app.py ADDED
@@ -0,0 +1,156 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import whisperx
+ from transformers import AutoTokenizer
+ from transformers import AutoModelForCausalLM
+ from transformers import CLIPVisionModel, CLIPImageProcessor
+ import peft
+ import gradio as gr
+
+ device = 'cpu'
+
+ # Hugging Face Hub repo that holds the QLoRA adapter weights for phi-2
+ user = "VarunSivamani"
+ model_name = "QLoRA-phi2"
+ model_id = f"{user}/{model_name}"
+
+ # Base language model; the adapter above is attached to it further down
+ model_name = "microsoft/phi-2"
+ phi2_model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     trust_remote_code=True,
+     device_map='cpu'
+ )
+ phi2_model.config.use_cache = False
+
+ # Speech-to-text and vision encoders
+ whisper_model = whisperx.load_model('small', device='cpu', compute_type='float32')
+ image_processor = CLIPImageProcessor.from_pretrained('openai/clip-vit-base-patch32')
+ clip_model = CLIPVisionModel.from_pretrained('openai/clip-vit-base-patch32')
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.bos_token = tokenizer.eos_token
+
+ def text_to_embeddings(text):
+     """Tokenize text and map it to phi-2 input embeddings."""
+     input_tokens = tokenizer(text, return_tensors="pt", return_attention_mask=False)
+     return phi2_model.get_input_embeddings()(input_tokens.input_ids)
+
+ def audio_to_text_embeds(file_name):
+     """Transcribe an audio file with whisperx and return the concatenated text."""
+     result = whisper_model.transcribe(file_name)
+     res_text = ''
+
+     for segment in result['segments']:
+         res_text = res_text + segment['text']
+
+     return res_text.strip()
+
+ def select_features(image_out):
+     # Last hidden state with the CLS token dropped (patch tokens only)
+     image_features = image_out.hidden_states[-1]
+     return image_features[:, 1:, :]
+
+ def CLIP_embeddings(image):
+     """Encode an image with CLIP and return its patch-level features."""
+     _ = clip_model.requires_grad_(False)
+     image = image_processor(images=image, return_tensors="pt")
+     image_out = clip_model(image['pixel_values'].to(device=clip_model.device), output_hidden_states=True)
+     return select_features(image_out)
+
+
+ class ResBlock(nn.Module):
+     """Residual MLP block applied after the image projection."""
+     def __init__(self, input_size):
+         super().__init__()
+         self.pre_norm = nn.LayerNorm(input_size)
+         self.proj = nn.Sequential(
+             nn.Linear(input_size, input_size),
+             nn.GELU(),
+             nn.Linear(input_size, input_size)
+         )
+
+     def forward(self, x):
+         x = self.pre_norm(x)
+         return x + self.proj(x)
+
+ class CLIP_projection(nn.Module):
+     """Projects 768-d CLIP patch embeddings into the 2560-d phi-2 embedding space."""
+     def __init__(
+         self,
+         dim_input_CLIP=768,
+         dim_input_Phi2=2560
+     ):
+         super(CLIP_projection, self).__init__()
+         self.projection_img = nn.Linear(
+             dim_input_CLIP, dim_input_Phi2, bias=False
+         )
+         self.resblock = ResBlock(dim_input_Phi2)
+
+     def forward(self, x):
+         x = self.projection_img(x)
+         return self.resblock(x)
+
+
+ # Trained projection weights, loaded from the working directory
+ proj_layer = CLIP_projection()
+ proj_layer.projection_img.load_state_dict(torch.load("proj.pth", map_location='cpu'))
+ proj_layer.resblock.load_state_dict(torch.load("block.pth", map_location='cpu'))
+
+
+ def img_embeddings(image):
+     """CLIP-encode an image and project it into the phi-2 embedding space."""
+     clip_embeddings = CLIP_embeddings(image)
+     return proj_layer(clip_embeddings)
+
+ # Attach the QLoRA adapter to the base phi-2 model
+ phi2_model_peft = peft.PeftModel.from_pretrained(phi2_model, model_id)
+
+ def multimodal_phi2(image=None, audio=None, text=None):
+     # Treat an empty textbox as "no text"
+     if text is not None and len(text) == 0:
+         text = None
+
+     if image is None and audio is None and text is None:
+         return None
+
+     # Default question string, so `query` is always defined even when only
+     # an image or only audio is supplied
+     query = text if text is not None else ""
+
+     # Build the prompt as a sequence of input embeddings: "Context: ", then
+     # the image / audio / text embeddings, then "Question: ... Answer: "
+     context = tokenizer("Context: ", return_tensors="pt", return_attention_mask=False)
+     input_embeds = phi2_model_peft.get_input_embeddings()(context.input_ids)
+
+     if image is not None:
+         image_embeds = img_embeddings(image)
+         input_embeds = torch.cat((input_embeds, image_embeds), dim=1)
+
+     if audio is not None:
+         audio_transcribed = audio_to_text_embeds(audio)
+         audio_embeds = text_to_embeddings(audio_transcribed)
+         input_embeds = torch.cat((input_embeds, audio_embeds), dim=1)
+
+     if text is not None:
+         text_embeds = text_to_embeddings(text)
+         input_embeds = torch.cat((input_embeds, text_embeds), dim=1)
+
+     question = tokenizer(" Question: " + query, return_tensors="pt", return_attention_mask=False)
+     question_embeds = phi2_model_peft.get_input_embeddings()(question.input_ids)
+     input_embeds = torch.cat((input_embeds, question_embeds), dim=1)
+
+     answer = tokenizer(" Answer: ", return_tensors="pt", return_attention_mask=False)
+     answer_embeds = phi2_model_peft.get_input_embeddings()(answer.input_ids)
+     input_embeds = torch.cat((input_embeds, answer_embeds), dim=1)
+
+     result = phi2_model_peft.generate(inputs_embeds=input_embeds, bos_token_id=tokenizer.bos_token_id)
+
+     # Decode and return the first non-empty segment after splitting on EOS
+     process = tokenizer.batch_decode(result)[0]
+     process = process.split(tokenizer.eos_token)
+
+     if process[0] == '':
+         return process[1]
+     else:
+         return process[0]
+
+
+ demo = gr.Interface(
+     fn=multimodal_phi2,
+     inputs=[
+         gr.Image(label="Image"),
+         gr.Audio(label="Audio", sources=["microphone", "upload"], type="filepath"),
+         gr.Textbox(label="Text"),
+     ],
+     outputs=[
+         gr.Textbox(label='Answer'),
+     ],
+ )
+
+ demo.launch()
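
For a quick sanity check outside the Gradio UI, the handler can also be called directly. The snippet below is a hypothetical example (not part of this commit); it assumes the module-level loading above succeeded, i.e. that proj.pth, block.pth and the VarunSivamani/QLoRA-phi2 adapter repo are all reachable.

    # Hypothetical smoke test: route a text-only query through the same
    # handler that backs the Gradio app.
    answer = multimodal_phi2(image=None, audio=None, text="What is a large language model?")
    print(answer)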