SahilJ2 committed on
Commit
39b01b9
1 Parent(s): e7c8703

First commit

Files changed (2)
  1. app.py +252 -0
  2. classifier.pth +3 -0
app.py ADDED
@@ -0,0 +1,252 @@
+ import os
+ import re
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import pandas as pd
+ import librosa
+ import gradio as gr
+ from PIL import Image
+ from gtts import gTTS
+ from torchvision import transforms
+ from torch.utils.data import Dataset, DataLoader, random_split
+ from torch.nn.utils import rnn
+ from sklearn.model_selection import train_test_split
+ from typing import List
+ from dataclasses import dataclass
+ from transformers import (
+     AutoModel, BertTokenizer, ViTImageProcessor,
+     WhisperProcessor, WhisperForConditionalGeneration,
+     BlipProcessor, BlipForQuestionAnswering,
+     AutoProcessor, AutoModelForCausalLM,
+     DonutProcessor, VisionEncoderDecoderModel,
+     Pix2StructProcessor, Pix2StructForConditionalGeneration,
+     AutoModelForSeq2SeqLM,
+ )
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ class LabelClassifier(nn.Module):
+     """Classifies a (question, image) pair into one of six VQA categories."""
+
+     def __init__(self):
+         super(LabelClassifier, self).__init__()
+         self.text_encoder = AutoModel.from_pretrained('bert-base-uncased')
+         self.image_encoder = AutoModel.from_pretrained('microsoft/swin-tiny-patch4-window7-224')
+         self.intermediate_dim = 128
+         # Fuse the concatenated BERT and Swin Transformer pooled outputs
+         self.fusion = nn.Sequential(
+             nn.Linear(self.text_encoder.config.hidden_size + self.image_encoder.config.hidden_size, self.intermediate_dim),
+             nn.ReLU(),
+             nn.Dropout(0.5),
+         )
+         self.classifier = nn.Linear(self.intermediate_dim, 6)  # six question categories
+
+         self.criterion = nn.CrossEntropyLoss()
+
+     def forward(self,
+                 input_ids: torch.LongTensor,
+                 pixel_values: torch.FloatTensor,
+                 attention_mask: torch.LongTensor = None,
+                 token_type_ids: torch.LongTensor = None,
+                 labels: torch.LongTensor = None):
+         encoded_text = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
+         encoded_image = self.image_encoder(pixel_values=pixel_values)
+
+         fused_state = self.fusion(torch.cat((encoded_text['pooler_output'], encoded_image['pooler_output']), dim=1))
+
+         # Pass through the classifier
+         logits = self.classifier(fused_state)
+
+         out = {"logits": logits}
+
+         if labels is not None:
+             out["loss"] = self.criterion(logits, labels)
+
+         return out
+
+ model = LabelClassifier().to(device)
+ model.load_state_dict(torch.load('classifier.pth', map_location=device))
+
+
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ processor = ViTImageProcessor.from_pretrained('microsoft/swin-tiny-patch4-window7-224')
+
+
+ def m1(que, image):
+     # General VQA: BLIP
+     processor3 = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
+     model3 = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large").to(device)
+
+     inputs = processor3(image, que, return_tensors="pt").to(device)
+
+     out = model3.generate(**inputs)
+     return processor3.decode(out[0], skip_special_tokens=True)
+
+ def m2(que, image):
+     # Scene-text VQA: GIT
+     processor3 = AutoProcessor.from_pretrained("microsoft/git-large-textvqa")
+     model3 = AutoModelForCausalLM.from_pretrained("microsoft/git-large-textvqa")
+
+     pixel_values = processor3(images=image, return_tensors="pt").pixel_values
+
+     input_ids = processor3(text=que, add_special_tokens=False).input_ids
+     input_ids = [processor3.tokenizer.cls_token_id] + input_ids
+     input_ids = torch.tensor(input_ids).unsqueeze(0)
+
+     generated_ids = model3.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
+     return processor3.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+ def m3(que, image):
+     # Document VQA: Donut
+     processor3 = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
+     model3 = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
+     model3.to(device)
+
+     prompt = f"<s_docvqa><s_question>{que}</s_question><s_answer>"
+     decoder_input_ids = processor3.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
+
+     pixel_values = processor3(image, return_tensors="pt").pixel_values
+
+     outputs = model3.generate(
+         pixel_values.to(device),
+         decoder_input_ids=decoder_input_ids.to(device),
+         max_length=model3.decoder.config.max_position_embeddings,
+         pad_token_id=processor3.tokenizer.pad_token_id,
+         eos_token_id=processor3.tokenizer.eos_token_id,
+         use_cache=True,
+         bad_words_ids=[[processor3.tokenizer.unk_token_id]],
+         return_dict_in_generate=True,
+     )
+
+     sequence = processor3.batch_decode(outputs.sequences)[0]
+     sequence = sequence.replace(processor3.tokenizer.eos_token, "").replace(processor3.tokenizer.pad_token, "")
+     sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
+     return processor3.token2json(sequence)['answer']
+
+ def m4(que, image):
+     # Chart/plot VQA: MatCha
+     processor3 = Pix2StructProcessor.from_pretrained('google/matcha-plotqa-v1')
+     model3 = Pix2StructForConditionalGeneration.from_pretrained('google/matcha-plotqa-v1')
+
+     inputs = processor3(images=image, text=que, return_tensors="pt")
+     predictions = model3.generate(**inputs, max_new_tokens=512)
+     return processor3.decode(predictions[0], skip_special_tokens=True)
+
+ def m5(que, image):
+     # OCR VQA: Pix2Struct
+     processor3 = AutoProcessor.from_pretrained("google/pix2struct-ocrvqa-large")
+     model3 = AutoModelForSeq2SeqLM.from_pretrained("google/pix2struct-ocrvqa-large").to(device)
+
+     inputs = processor3(images=image, text=que, return_tensors="pt").to(device)
+
+     predictions = model3.generate(**inputs)
+     return processor3.decode(predictions[0], skip_special_tokens=True)
+
+ def m6(que, image):
+     # Infographics VQA: Pix2Struct
+     processor3 = AutoProcessor.from_pretrained("google/pix2struct-infographics-vqa-large")
+     model3 = AutoModelForSeq2SeqLM.from_pretrained("google/pix2struct-infographics-vqa-large").to(device)
+
+     inputs = processor3(images=image, text=que, return_tensors="pt").to(device)
+
+     predictions = model3.generate(**inputs)
+     return processor3.decode(predictions[0], skip_special_tokens=True)
+
+
+ def predict_answer(category, que, image):
+     # Dispatch to the VQA model that matches the predicted question category
+     if category == 0:
+         return m1(que, image)
+     elif category == 1:
+         return m2(que, image)
+     elif category == 2:
+         return m3(que, image)
+     elif category == 3:
+         return m4(que, image)
+     elif category == 4:
+         return m5(que, image)
+     else:
+         return m6(que, image)
+
+
+ def transcribe_audio(audio):
+     processor2 = WhisperProcessor.from_pretrained("openai/whisper-large-v3", language='en')
+     model2 = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
+
+     # gr.Microphone delivers audio as (sampling_rate, int16 numpy array)
+     sampling_rate = audio[0]
+     audio_data = audio[1]
+
+     # Convert to float in [-1, 1] and resample to the 16 kHz rate Whisper expects
+     audio_data_float = np.array(audio_data).astype(np.float32) / 32768.0
+     resampled_audio_data = librosa.resample(audio_data_float, orig_sr=sampling_rate, target_sr=16000)
+
+     # Use the model and processor to transcribe the audio
+     input_features = processor2(
+         resampled_audio_data, sampling_rate=16000, return_tensors="pt"
+     ).input_features
+
+     # Generate token ids
+     predicted_ids = model2.generate(input_features)
+
+     # Decode token ids to text
+     transcription = processor2.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+
+     return transcription
+
+
+ def predict_category(que, input_image):
+     # Encode the question with the BERT tokenizer and the image with the Swin processor
+     encoded_text = tokenizer(
+         text=que,
+         padding='longest',
+         max_length=24,
+         truncation=True,
+         return_tensors='pt',
+         return_token_type_ids=True,
+         return_attention_mask=True,
+     )
+
+     encoded_image = processor(input_image, return_tensors='pt').to(device)
+
+     batch = {
+         'input_ids': encoded_text['input_ids'].to(device),
+         'token_type_ids': encoded_text['token_type_ids'].to(device),
+         'attention_mask': encoded_text['attention_mask'].to(device),
+         'pixel_values': encoded_image['pixel_values'].to(device),
+     }
+
+     output = model(input_ids=batch['input_ids'],
+                    token_type_ids=batch['token_type_ids'],
+                    attention_mask=batch['attention_mask'],
+                    pixel_values=batch['pixel_values'])
+
+     preds = output["logits"].argmax(axis=-1).cpu().numpy()
+
+     return preds[0]
+
+
+ def combine(audio, input_image):
+     que = transcribe_audio(audio)
+
+     image = Image.fromarray(input_image).convert('RGB')
+     category = predict_category(que, image)
+
+     # Route the question to the VQA model selected by the classifier
+     answer = predict_answer(category, que, image)
+
+     # Speak the answer back with gTTS
+     tts = gTTS(answer)
+     tts.save('answer.mp3')
+     return que, answer, 'answer.mp3'
+
+
+ # Gradio interface: record a spoken question, upload an image, and get back the
+ # transcribed question, the text answer, and a spoken answer
+ model_interface = gr.Interface(
+     fn=combine,
+     inputs=[gr.Microphone(label="Ask your question"), gr.Image(label="Upload the image")],
+     outputs=[gr.Text(label="Transcribed Question"), gr.Text(label="Answer"), gr.Audio(label="Audio Answer")],
+ )
+
+ # Launch the Gradio interface
+ model_interface.launch(debug=True)
classifier.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cd3c37a6110305f0641190ac90f5db3e527056f5a9bdbe2c11214256435c62fa
+ size 549215152