TharunSivamani committed on
Commit 5f3e0e0
Parent(s): 946ede3

application file

Files changed (1)
  1. app.py +156 -0
app.py ADDED
@@ -0,0 +1,156 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import whisperx
+ from transformers import AutoTokenizer
+ from transformers import AutoModelForCausalLM
+ from transformers import CLIPVisionModel, CLIPImageProcessor
+ import peft
+ import gradio as gr
+
+ # Everything runs on CPU in float32: whisperx 'small' for speech-to-text,
+ # CLIP ViT-B/32 as the vision encoder, and phi-2 as the language model.
+ device = 'cpu'
+ model_name = "microsoft/phi-2"
+ whisper_model = whisperx.load_model('small', device='cpu', compute_type='float32')
+ image_processor = CLIPImageProcessor.from_pretrained('openai/clip-vit-base-patch32')
+ clip_model = CLIPVisionModel.from_pretrained('openai/clip-vit-base-patch32')
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.bos_token = tokenizer.eos_token
+
+ phi2_model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     trust_remote_code=True,
+     device_map='cpu'
+ )
+ phi2_model.config.use_cache = False
+
+
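+ # Modality helpers: CLIP_embeddings/features extract frozen CLIP patch features
+ # from an image (CLS token dropped), embed_audio transcribes speech with whisperx,
+ # and embed_text maps a string to phi-2 input embeddings.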
+ def CLIP_embeddings(image):
+     # CLIP is used as a frozen feature extractor
+     _ = clip_model.requires_grad_(False)
+     image = image_processor(images=image, return_tensors="pt")
+     image_out = clip_model(image['pixel_values'].to(device=clip_model.device), output_hidden_states=True)
+     return features(image_out)
+
+ def embed_audio(file_name):
+     # Transcribe the audio file with whisperx and join all segments
+     result = whisper_model.transcribe(file_name)
+     res_text = ''
+     for segment in result['segments']:
+         res_text = res_text + segment['text']
+     return res_text.strip()
+
+ def features(image_out):
+     # Last hidden layer of the vision tower, with the CLS token dropped
+     image_features = image_out.hidden_states[-1]
+     return image_features[:, 1:, :]
+
+ def embed_text(text):
+     input_tokens = tokenizer(text, return_tensors="pt", return_attention_mask=False)
+     return phi2_model.get_input_embeddings()(input_tokens.input_ids)
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, input_size):
+         super().__init__()
+         self.pre_norm = nn.LayerNorm(input_size)
+         self.proj = nn.Sequential(
+             nn.Linear(input_size, input_size),
+             nn.GELU(),
+             nn.Linear(input_size, input_size)
+         )
+
+     def forward(self, x):
+         x = self.pre_norm(x)
+         return x + self.proj(x)
+
+ class Projection_Model(nn.Module):
+     # Maps 768-d CLIP patch features into phi-2's 2560-d embedding space
+     def __init__(
+         self,
+         dim_input_CLIP=768,
+         dim_input_Phi2=2560
+     ):
+         super().__init__()
+         self.projection_img = nn.Linear(
+             dim_input_CLIP, dim_input_Phi2, bias=False
+         )
+         self.resblock = ResBlock(dim_input_Phi2)
+
+     def forward(self, x):
+         x = self.projection_img(x)
+         return self.resblock(x)
+
+
+ # Load the pretrained projection weights shipped alongside app.py
+ model = Projection_Model()
+ model.projection_img.load_state_dict(torch.load("projection.pth", map_location='cpu'))
+ model.resblock.load_state_dict(torch.load("block.pth", map_location='cpu'))
+
+
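+ # Note: with the default 224x224 CLIP preprocessing and a patch size of 32, each
+ # image yields a 7x7 grid of patch tokens, i.e. 49 projected vectors of size 2560
+ # that are concatenated into the phi-2 prompt.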
+ def embeddings_image(image):
+     # CLIP features -> projection network; no gradients are needed at inference
+     with torch.no_grad():
+         clip_embeddings = CLIP_embeddings(image)
+         return model(clip_embeddings)
+
+ # QLoRA adapter fine-tuned on top of phi-2
+ user = "TharunSivamani"
+ adapter_name = "qlora-phi2"
+ model_id = f"{user}/{adapter_name}"
+
+ phi2_model_peft = peft.PeftModel.from_pretrained(phi2_model, model_id)
+
+
+ def inference(image=None, audio=None, text=None):
+     # Treat an empty textbox as "no text input"
+     if not text:
+         text = None
+
+     if image is None and audio is None and text is None:
+         return None
+
+     # The prompt is assembled directly in embedding space:
+     # "Context: " [image tokens] [audio transcript tokens] [text tokens] " Question: <query> Answer: "
+     context = tokenizer("Context: ", return_tensors="pt", return_attention_mask=False)
+     input_embeds = phi2_model_peft.get_input_embeddings()(context.input_ids)
+
+     # Default to an empty question so image-only or audio-only inputs do not fail
+     query = ''
+
+     if image is not None:
+         image_embeds = embeddings_image(image)
+         input_embeds = torch.cat((input_embeds, image_embeds), dim=1)
+
+     if audio is not None:
+         audio_transcribed = embed_audio(audio)
+         audio_embeds = embed_text(audio_transcribed)
+         input_embeds = torch.cat((input_embeds, audio_embeds), dim=1)
+
+     if text is not None:
+         query = text
+         text_embeds = embed_text(text)
+         input_embeds = torch.cat((input_embeds, text_embeds), dim=1)
+
+     question = tokenizer(" Question: " + query, return_tensors="pt", return_attention_mask=False)
+     question_embeds = phi2_model_peft.get_input_embeddings()(question.input_ids)
+     input_embeds = torch.cat((input_embeds, question_embeds), dim=1)
+
+     answer = tokenizer(" Answer: ", return_tensors="pt", return_attention_mask=False)
+     answer_embeds = phi2_model_peft.get_input_embeddings()(answer.input_ids)
+     input_embeds = torch.cat((input_embeds, answer_embeds), dim=1)
+
+     result = phi2_model_peft.generate(inputs_embeds=input_embeds, bos_token_id=tokenizer.bos_token_id)
+     final_ans = tokenizer.batch_decode(result)[0]
+     final_ans = final_ans.split(tokenizer.eos_token)
+
+     # Return the text before the first EOS token, skipping a leading EOS if present
+     if final_ans[0] == '':
+         return final_ans[1]
+     else:
+         return final_ans[0]
+
+
+ demo = gr.Interface(
+     fn=inference,
+     inputs=[
+         gr.Image(label="Image Input"),
+         gr.Audio(label="Audio Input", sources=["microphone", "upload"], type="filepath"),
+         gr.Textbox(label="Text Input"),
+     ],
+     outputs=[
+         gr.Textbox(label='Answer'),
+     ],
+ )
+
+ demo.launch()
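As a quick local sanity check, one could call inference() directly just above demo.launch(). The snippet below is only a sketch: it assumes Pillow is installed, the projection weight files are in place, and a test image named sample.jpg (hypothetical) sits next to app.py.

    from PIL import Image

    sample = Image.open("sample.jpg")  # hypothetical test image
    print(inference(image=sample, audio=None, text="What is shown in the picture?"))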