lzun committed on
Commit
9f17786
·
verified ·
1 Parent(s): efbd898

Upload 3 files

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. app.py +227 -0
  3. requirements.txt +5 -0
  4. sum_model.sav +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ sum_model.sav filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# for deep learning models
import torch
from transformers import pipeline, AutoTokenizer, AutoFeatureExtractor, ViTModel

# for utility
import numpy as np
import joblib

# for app demo
import gradio as gr

# some global variables: pin every RNG so runs are reproducible
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Pick the best available accelerator for inference:
# CUDA GPU first, then Apple Silicon MPS, otherwise plain CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
26
+
27
class mm_inference:
    """Multimodal (text + image) sentiment inference for Spanish social media posts.

    Loads a pre-trained sentiment classifier from ``sum_model.sav`` and fuses
    embeddings from a Spanish BERT text model and a ViT image model by summing
    them ("sum fusion"), then predicts an integer sentiment label.
    """

    def __init__(self, text, img1, img2, img3, max_images = 3, incl_text_flag = True, incl_image_flag = True, incl_text_in_img_flag = True):
        """Set up models and store the inputs to run inference on.

        Inputs
        ------
        text: post text (Spanish) to embed.
        img1, img2, img3: image file paths (or None when a slot is unused).
        max_images: maximum number of image slots (default 3).
        incl_text_flag: include the text embedding in the fusion.
        incl_image_flag: include the image embeddings in the fusion.
        incl_text_in_img_flag: kept for interface compatibility; not used here.
        """
        path_to_model = 'sum_model.sav'
        # Use a context manager so the model file handle is always closed
        # (the original left the file object dangling).
        with open(path_to_model, 'rb') as f:
            self.model = joblib.load(f)
        self.text = text
        self.imgs = [img1, img2, img3]
        self.max_images = max_images
        self.incl_text_flag = incl_text_flag
        self.incl_image_flag = incl_image_flag
        self.text_model_ckpt = 'dccuchile/bert-base-spanish-wwm-uncased'
        self.img_model_ckpt = 'microsoft/swin-base-patch4-window7-224'

        # tokenizer / feature-extractor preprocessing for each modality
        self.tokenizer = AutoTokenizer.from_pretrained(self.text_model_ckpt)
        self.img_feature_extractor = AutoFeatureExtractor.from_pretrained(self.img_model_ckpt)

        # fine-tuned backbones used purely for feature extraction
        self.text_feature_extractor = pipeline(task = 'feature-extraction', model = 'lzun/spanish-social-media-boxing-text', tokenizer = self.tokenizer, return_tensors = True, device = device)
        self.img_model = ViTModel.from_pretrained('lzun/spanish-social-media-boxing-images')

    def get_text_embs(self, text):
        '''
        Feature extraction using no model head: the pipeline returns the
        hidden states of the base transformer, shaped
        (batch_size, sequence_length, hidden_size) — e.g. [1, 33, 768] for
        31 tokens plus the [CLS] and [SEP] special tokens.

        Inputs
        ------
        text: text string to embed.

        Returns
        -------
        Torch tensor of size 768: the [CLS]-position embedding (first token
        of the last layer), used as the sentence representation.
        '''
        # get embeddings from text using the pipeline
        text_embs = self.text_feature_extractor(text)

        # return CLS token (first position of the last layer)
        return text_embs[0][0]

    def get_img_embs(self, image_path):
        """
        Embed a single image with the ViT backbone.

        The model returns hidden states of shape
        (batch_size, n_patches + 1, hidden_dim), e.g. [1, 197, 768]; the
        first position is the [CLS] token, used as the image representation.

        Inputs
        ------
        image_path: path to the image file.

        Returns
        -------
        Torch tensor of size 768 (the [CLS] embedding of the last layer).
        """
        # The original called ``io.imread`` without ever importing ``io``
        # (skimage), which raises NameError. transformers is already a
        # dependency of this app, so use its image loader instead.
        from transformers.image_utils import load_image
        img = load_image(image_path)

        # feature extractor -> model forward pass
        inputs = self.img_feature_extractor(images=img, return_tensors="pt")
        outputs = self.img_model(**inputs)
        last_hidden_states = outputs.last_hidden_state

        return last_hidden_states[0][0]

    def count_images(self):
        """Return how many image slots actually contain an image (non-None)."""
        # Derive from the list itself instead of a hardcoded 3 so the count
        # stays correct if the number of slots ever changes.
        return sum(1 for item in self.imgs if item is not None)

    def predict(self):
        """
        Fuse the enabled modality embeddings by summation and classify.

        Image embeddings (when enabled) and the text embedding (when enabled)
        are summed into a single 768-dim vector, which is fed to the loaded
        classifier.

        Returns
        -------
        int: predicted sentiment label.
        """
        # -------- sum-fusion accumulator --------
        emb_sum = np.zeros(768)

        # -------- image embeddings --------
        if self.incl_image_flag:
            for img_path in self.imgs:
                if img_path is None:
                    continue
                # Best-effort per image: a single unreadable image should not
                # abort the whole prediction (matches original behavior, but
                # without a bare ``except:`` that would also swallow
                # KeyboardInterrupt/SystemExit).
                try:
                    emb_sum += self.get_img_embs(img_path).detach().numpy()
                except Exception:
                    pass

        # -------- text embedding --------
        if self.incl_text_flag:
            emb_sum += self.get_text_embs(self.text).detach().numpy()

        # -------- infer overall sentiment --------
        sent = int(self.model.predict(emb_sum.reshape(1, -1))[0])
        return sent
184
+
185
def main():
    """Build and launch the Gradio demo UI for the multimodal sentiment predictor.

    Lays out a text box, up to three image inputs, a predict button, and a
    label output; the click handler runs ``mm_inference`` on the inputs.
    """

    with gr.Blocks() as demo:

        gr.Markdown("# Multimodal Spanish COVID-19 Sentiment Polarity Predictor")
        gr.Markdown("## Input text from a social media post (like X or Instagram)")
        text = gr.Textbox(label="Text from publication")
        gr.Markdown("## Input images from a social media post (min 1, max 3)")
        with gr.Row():
            img1 = gr.Image(label="Image #1 from the publication (mandatory)", type="filepath")
            img2 = gr.Image(label="Image #2 from the publication (if available)", type="filepath")

        with gr.Row():
            img3 = gr.Image(label="Image #3 from the publication (if available)", type="filepath")
        pred_btn = gr.Button("Predict Sentiment")
        gr.Markdown("## Predicted output")
        output = gr.Label(label="Sentiment value")

        # Click handler: was named ``test1`` and printed every input for
        # debugging; renamed for clarity and debug prints removed.
        def predict_sentiment(text, img1, img2, img3):
            """Run multimodal inference on the form inputs and return the label."""
            inferencer = mm_inference(text,
                                      img1,
                                      img2,
                                      img3)

            # predict and return label
            return inferencer.predict()

        pred_btn.click(predict_sentiment, inputs=[text, img1, img2, img3], outputs = output)

    demo.launch()


if __name__ == '__main__':
    main()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ joblib
4
+ gradio
5
+ numpy
sum_model.sav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6eacba77e320df01885793b55ee1607f254a03af1ced4e31ee659cc82fa0dce2
3
+ size 1767615