MarieAngeA13 committed on
Commit
6162fcf
1 Parent(s): 8d6722d

Upload sentiment_analysis.py

Browse files
Files changed (1) hide show
  1. sentiment_analysis.py +602 -0
sentiment_analysis.py ADDED
@@ -0,0 +1,602 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Sentiment_analysis.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1EHgMQQJzwbNja0JVMM2DVvrVTMHIS3Vg
8
+ """
9
+
10
# NOTE: the notebook cell "!pip install transformers" is an IPython/Colab shell
# magic and is a syntax error in a plain .py module. Install the dependency
# before running this script instead:
#     pip install transformers
11
+
12
+ import pandas as pd
13
+ from wordcloud import WordCloud
14
+ import seaborn as sns
15
+ import re
16
+ import string
17
+ from collections import Counter, defaultdict
18
+
19
+ from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
20
+
21
+ import plotly.express as px
22
+ from plotly.subplots import make_subplots
23
+ import plotly.graph_objects as go
24
+ from plotly.offline import plot
25
+
26
+ import matplotlib.gridspec as gridspec
27
+ from matplotlib.ticker import MaxNLocator
28
+ import matplotlib.patches as mpatches
29
+ import matplotlib.pyplot as plt
30
+ import warnings
31
+ warnings.filterwarnings('ignore')
32
+ import nltk
33
+ nltk.download('stopwords')
34
+ from nltk.corpus import stopwords
35
+ stopWords_nltk = set(stopwords.words('english'))
36
+
37
+
38
+ import re
39
+ from typing import Union, List
40
+
41
class CleanText():
    """Callable that blanks out every character not whitelisted by a regex.

    The default pattern keeps ASCII and Turkish letters, digits and the
    characters . , " ' ( ) ; anything else is replaced by a single space.
    """

    def __init__(self, clean_pattern = r"[^A-ZĞÜŞİÖÇIa-zğüı'şöç0-9.\"',()]"):
        # Regex matching the characters that must be replaced by a space.
        self.clean_pattern = clean_pattern

    def __call__(self, text: Union[str, list]) -> str:
        """Clean *text* and return it as one space-joined string.

        Args:
            text: a single string, a list of strings, or a list of documents
                where each document is a list of sentence strings.

        Returns:
            All cleaned sentences joined with single spaces.

        Raises:
            TypeError: if *text* is neither a ``str`` nor a ``list``.
        """
        if isinstance(text, str):
            docs = [[text]]
        elif isinstance(text, list):
            # A bare string inside the list is treated as a one-sentence
            # document; iterating it directly would split it into characters.
            docs = [[item] if isinstance(item, str) else item for item in text]
        else:
            raise TypeError(
                f"text must be a str or a list, got {type(text).__name__}"
            )

        cleaned = [
            [re.sub(self.clean_pattern, " ", sent) for sent in sents]
            for sents in docs
        ]

        # Flatten the list of lists into a single string.
        return ' '.join(' '.join(sents) for sents in cleaned)
61
+
62
def remove_emoji(data):
    """Return *data* with every run of emoji-range characters removed.

    NOTE(review): the ranges 24C2-1F251 and 10000-10FFFF are far wider than
    emoji alone (they also cover e.g. CJK ideographs), so such characters are
    stripped as well.
    """
    char_ranges = (
        u"\U0001F600-\U0001F64F",  # emoticons
        u"\U0001F300-\U0001F5FF",  # symbols & pictographs
        u"\U0001F680-\U0001F6FF",  # transport & map symbols
        u"\U0001F1E0-\U0001F1FF",  # flags (iOS)
        u"\U00002500-\U00002BEF",
        u"\U00002702-\U000027B0",  # dingbats
        u"\U00002702-\U000027B0",  # (repeated in the original pattern)
        u"\U000024C2-\U0001F251",
        u"\U0001f926-\U0001f937",
        u"\U00010000-\U0010ffff",
        u"\u2640-\u2642",
        u"\u2600-\u2B55",
        u"\u200d",  # zero-width joiner
        u"\u23cf",
        u"\u23e9",
        u"\u231a",
        u"\ufe0f",  # variation selector-16
        u"\u3030",
    )
    emoji_pattern = re.compile("[" + "".join(char_ranges) + "]+", re.UNICODE)
    return emoji_pattern.sub('', data)
84
+
85
def tokenize(text):
    """Split *text* into digit runs, letter runs and single non-word characters.

    Runs of spaces are collapsed first; empty pieces and lone spaces are
    dropped, and the remaining tokens are returned joined by single spaces.
    """
    collapsed = re.sub(r" +", " ", str(text))
    pieces = re.split(r"(\d+|[a-zA-ZğüşıöçĞÜŞİÖÇ]+|\W)", collapsed)
    return ' '.join(piece for piece in pieces if piece not in ('', ' '))
92
+
93
# Character class matching any single character of string.punctuation.
regex = re.compile('[%s]' % re.escape(string.punctuation))


def remove_punct(text):
    """Return *text* with each ASCII punctuation character replaced by a space."""
    return regex.sub(" ", text)
98
+
99
# Shared text cleaner built with the default character whitelist.
clean = CleanText()
100
+
101
def label_encode(x):
    """Map a star rating to a class id: 1-2 -> 0, 3 -> 1, 4-5 -> 2.

    Any other value yields ``None``.
    """
    if x in (1, 2):
        return 0
    if x == 3:
        return 1
    if x in (4, 5):
        return 2
    return None
108
+
109
def label2name(x):
    """Return the sentiment name for class id *x* (0, 1 or 2), else ``None``.

    NOTE: further down this script the name ``label2name`` is rebound to a
    dict with the same id -> name mapping.
    """
    for class_id, class_name in enumerate(("Negative", "Neutral", "Positive")):
        if x == class_id:
            return class_name
    return None
116
+
117
from google.colab import files

# Colab-only: opens an upload dialog so the CSV lands in the working directory.
uploaded = files.upload()
df = pd.read_csv('tripadvisor_hotel_reviews.csv')

print("df.columns: ", df.columns)

# Distribution of the raw 1-5 star ratings.
fig = px.histogram(
    df,
    x='Rating',
    title='Histogram of Review Rating',
    template='ggplot2',
    color='Rating',
    color_discrete_sequence=px.colors.sequential.Blues_r,
    opacity=0.8,
    height=525,
    width=835,
)

fig.update_yaxes(title='Count')
fig.show()

df.info()

# Collapse the star rating into three sentiment classes.
df["label"] = df["Rating"].apply(lambda x: label_encode(x))
df["label_name"] = df["label"].apply(lambda x: label2name(x))

# Normalise the review text: strip emoji, lowercase, whitelist characters,
# then blank out punctuation. `clean` already returns a plain string, so it
# must not be indexed (indexing with [0][0] kept only the first character).
df["Review"] = df["Review"].apply(
    lambda x: remove_punct(clean(remove_emoji(x).lower()))
)

df.head()

# Class balance as a pie chart and a bar chart. Labels and counts come from a
# single value_counts() call so they are guaranteed to line up.
label_counts = df.label_name.value_counts()

fig = make_subplots(rows=1, cols=2, specs=[[{"type": "pie"}, {"type": "bar"}]])
colors = ['gold', 'mediumturquoise', 'lightgreen']  # darkorange
fig.add_trace(go.Pie(labels=label_counts.index,
                     values=label_counts.values), 1, 1)

fig.update_traces(hoverinfo='label+percent', textfont_size=20,
                  marker=dict(colors=colors, line=dict(color='#000000', width=2)))

fig.add_trace(go.Bar(x=label_counts.index, y=label_counts.values,
                     marker_color=colors), 1, 2)

fig.show()
157
+
158
+ import pandas as pd
159
+ import numpy as np
160
+ import os
161
+ import random
162
+ from pathlib import Path
163
+ import json
164
+
165
+ import torch
166
+ from tqdm.notebook import tqdm
167
+
168
+ from transformers import BertTokenizer
169
+ from torch.utils.data import TensorDataset
170
+
171
+ from transformers import BertForSequenceClassification
172
+
173
class Config():
    """Constants for the BERT fine-tuning run (read as class attributes)."""

    # Reproducibility and data split
    seed_val = 17
    random_state = 42
    test_size = 0.15

    # Hardware: first CUDA device when available, otherwise CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Optimisation
    epochs = 5
    batch_size = 6
    lr = 2e-5
    eps = 1e-8

    # Model and tokenizer options
    pretrained_model = 'bert-base-uncased'
    seq_length = 512
    add_special_tokens = True
    return_attention_mask = True
    pad_to_max_length = True
    # NOTE(review): the checkpoint name says "uncased"; confirm that disabling
    # lowercasing in the tokenizer is intended.
    do_lower_case = False
    return_tensors = 'pt'
189
config = Config()

# params will be saved after training
params = {
    "seed_val": config.seed_val,
    "device": str(config.device),
    "epochs": config.epochs,
    "batch_size": config.batch_size,
    "seq_length": config.seq_length,
    "lr": config.lr,
    "eps": config.eps,
    "pretrained_model": config.pretrained_model,
    "test_size": config.test_size,
    "random_state": config.random_state,
    "add_special_tokens": config.add_special_tokens,
    "return_attention_mask": config.return_attention_mask,
    "pad_to_max_length": config.pad_to_max_length,
    "do_lower_case": config.do_lower_case,
    "return_tensors": config.return_tensors,
}

import random

device = config.device

# Seed every random number generator used by the run.
random.seed(config.seed_val)
np.random.seed(config.seed_val)
torch.manual_seed(config.seed_val)
torch.cuda.manual_seed_all(config.seed_val)

df.head()

from sklearn.model_selection import train_test_split

# Hold out 10% of the data for validation, stratified by class.
# NOTE(review): the fraction is hard-coded to 0.10 while params records
# config.test_size (0.15) — confirm which value is intended.
train_df_, val_df = train_test_split(df,
                                     test_size=0.10,
                                     random_state=config.random_state,
                                     stratify=df.label.values)

train_df_.head()

# Hold out a further 10% of the remainder as the test set.
train_df, test_df = train_test_split(train_df_,
                                     test_size=0.10,
                                     random_state=config.random_state,
                                     stratify=train_df_.label.values)

print(len(train_df['label'].unique()))
print(train_df.shape)

print(len(val_df['label'].unique()))
print(val_df.shape)

print(len(test_df['label'].unique()))
print(test_df.shape)

tokenizer = BertTokenizer.from_pretrained(config.pretrained_model,
                                          do_lower_case=config.do_lower_case)

# Encode the train and validation reviews as fixed-length id tensors.
encoded_data_train = tokenizer.batch_encode_plus(
    train_df.Review.values,
    add_special_tokens=config.add_special_tokens,
    return_attention_mask=config.return_attention_mask,
    pad_to_max_length=config.pad_to_max_length,
    max_length=config.seq_length,
    return_tensors=config.return_tensors
)
encoded_data_val = tokenizer.batch_encode_plus(
    val_df.Review.values,
    add_special_tokens=config.add_special_tokens,
    return_attention_mask=config.return_attention_mask,
    pad_to_max_length=config.pad_to_max_length,
    max_length=config.seq_length,
    return_tensors=config.return_tensors
)

input_ids_train = encoded_data_train['input_ids']
attention_masks_train = encoded_data_train['attention_mask']
labels_train = torch.tensor(train_df.label.values)

input_ids_val = encoded_data_val['input_ids']
attention_masks_val = encoded_data_val['attention_mask']
labels_val = torch.tensor(val_df.label.values)

dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)
dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)

model = BertForSequenceClassification.from_pretrained(config.pretrained_model,
                                                      num_labels=3,
                                                      output_attentions=False,
                                                      output_hidden_states=False)

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

# Shuffle training batches; keep validation order fixed.
dataloader_train = DataLoader(dataset_train,
                              sampler=RandomSampler(dataset_train),
                              batch_size=config.batch_size)

dataloader_validation = DataLoader(dataset_val,
                                   sampler=SequentialSampler(dataset_val),
                                   batch_size=config.batch_size)

from transformers import get_linear_schedule_with_warmup

# torch.optim.AdamW replaces the deprecated transformers.AdamW.
# weight_decay=0.0 reproduces that class's default (torch defaults to 0.01).
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=config.lr,
                              eps=config.eps,
                              weight_decay=0.0)

# Linear decay of the learning rate over all optimisation steps, no warm-up.
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=len(dataloader_train)*config.epochs)
299
+
300
+ from sklearn.metrics import f1_score
301
+
302
def f1_score_func(preds, labels):
    """Return the weighted F1 score of the arg-max predictions against *labels*.

    *preds* is reduced with ``argmax`` along axis 1; both arrays are flattened
    before scoring.
    """
    predicted = np.argmax(preds, axis=1).flatten()
    actual = labels.flatten()
    return f1_score(actual, predicted, average='weighted')
306
+
307
def accuracy_per_class(preds, labels, label_dict):
    """Print, for each class present in *labels*, hits/total of the arg-max predictions.

    *label_dict* maps class name -> class id; it is inverted to print names.
    """
    id_to_name = {class_id: name for name, class_id in label_dict.items()}

    predicted = np.argmax(preds, axis=1).flatten()
    actual = labels.flatten()

    for class_id in np.unique(actual):
        in_class = actual == class_id
        hits = int(np.sum(predicted[in_class] == class_id))
        total = int(np.sum(in_class))
        print(f'Class: {id_to_name[class_id]}')
        print(f'Accuracy: {hits}/{total}\n')
318
+
319
def evaluate(dataloader_val):
    """Run the global ``model`` over *dataloader_val* without gradients.

    Each batch is read as (input_ids, attention_mask, labels) and moved to
    ``config.device``.

    Returns:
        (mean loss per batch, logits of all batches concatenated on axis 0,
        labels of all batches concatenated on axis 0).
    """
    model.eval()

    total_loss = 0
    all_logits = []
    all_labels = []

    for batch in dataloader_val:
        moved = [tensor.to(config.device) for tensor in batch]

        with torch.no_grad():
            outputs = model(input_ids=moved[0],
                            attention_mask=moved[1],
                            labels=moved[2])

        # With labels supplied, the output starts with the loss, then the logits.
        total_loss += outputs[0].item()
        all_logits.append(outputs[1].detach().cpu().numpy())
        all_labels.append(moved[2].cpu().numpy())

    # calculate average validation loss
    mean_loss = total_loss / len(dataloader_val)

    stacked_logits = np.concatenate(all_logits, axis=0)
    stacked_labels = np.concatenate(all_labels, axis=0)

    return mean_loss, stacked_logits, stacked_labels
354
+
355
config.device

model.to(config.device)

for epoch in tqdm(range(1, config.epochs+1)):

    model.train()

    loss_train_total = 0
    # allows you to see the progress of the training
    progress_bar = tqdm(dataloader_train, desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)

    for batch in progress_bar:

        model.zero_grad()

        batch = tuple(b.to(config.device) for b in batch)

        inputs = {'input_ids': batch[0],
                  'attention_mask': batch[1],
                  'labels': batch[2],
                  }

        outputs = model(**inputs)

        loss = outputs[0]
        loss_train_total += loss.item()
        loss.backward()

        # Clip the gradient norm to 1.0 before the optimiser step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()
        # Show the batch loss as-is: `batch` is the (ids, mask, labels) tuple,
        # so dividing by len(batch) would divide by 3, not by a batch size.
        progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item())})

    # One checkpoint per epoch.
    torch.save(model.state_dict(), f'_BERT_epoch_{epoch}.model')

    tqdm.write(f'\nEpoch {epoch}')

    loss_train_avg = loss_train_total/len(dataloader_train)
    tqdm.write(f'Training loss: {loss_train_avg}')

    val_loss, predictions, true_vals = evaluate(dataloader_validation)
    val_f1 = f1_score_func(predictions, true_vals)
    tqdm.write(f'Validation loss: {val_loss}')

    tqdm.write(f'F1 Score (Weighted): {val_f1}')

# save model params and other configs
with Path('params.json').open("w", encoding="utf-8") as f:
    json.dump(params, f, ensure_ascii=False, indent=4)
406
+
407
# Restore the checkpoint written after epoch 3.
model.load_state_dict(torch.load('./_BERT_epoch_3.model', map_location=torch.device('cpu')))

# `predictions` / `true_vals` left over from the training loop belong to the
# last epoch; recompute them so the report describes the checkpoint just loaded.
_, predictions, true_vals = evaluate(dataloader_validation)

from sklearn.metrics import classification_report

preds_flat = np.argmax(predictions, axis=1).flatten()
# classification_report expects (y_true, y_pred) in that order.
print(classification_report(true_vals, preds_flat))

# Predict each validation review individually.
pred_final = []

model.eval()
for i, row in tqdm(val_df.iterrows(), total=val_df.shape[0]):
    review = row["Review"]
    encoded_data_test_single = tokenizer.batch_encode_plus(
        [review],
        add_special_tokens=config.add_special_tokens,
        return_attention_mask=config.return_attention_mask,
        pad_to_max_length=config.pad_to_max_length,
        max_length=config.seq_length,
        return_tensors=config.return_tensors
    )
    input_ids_test = encoded_data_test_single['input_ids']
    attention_masks_test = encoded_data_test_single['attention_mask']

    inputs = {'input_ids': input_ids_test.to(device),
              'attention_mask': attention_masks_test.to(device),
              }

    with torch.no_grad():
        outputs = model(**inputs)

    # No labels were passed, so the first output is the logits.
    logits = outputs[0].detach().cpu().numpy()
    pred_final.append(np.argmax(logits, axis=1).flatten()[0])

# Work on a copy so the new columns are not written into a slice of `df`.
val_df = val_df.copy()
val_df["pred"] = pred_final
# Add control column for easier wrong and right predictions
control = val_df.pred.values == val_df.label.values
val_df["control"] = control
# filtering false predictions: keep only the misclassified rows
val_df = val_df[~val_df.control].copy()

name2label = {"Negative": 0,
              "Neutral": 1,
              "Positive": 2
              }
# NOTE: from here on `label2name` is this dict, replacing the earlier function.
label2name = {v: k for k, v in name2label.items()}

val_df["pred_name"] = val_df.pred.apply(lambda x: label2name.get(x))
from sklearn.metrics import confusion_matrix

# We create a confusion matrix to better observe the classes that the model confuses.
# It is built from the misclassified rows only.
pred_name_values = val_df.pred_name.values
label_values = val_df.label_name.values
confmat = confusion_matrix(label_values, pred_name_values, labels=list(name2label.keys()))

confmat

df_confusion_val = pd.crosstab(label_values, pred_name_values)
df_confusion_val

df_confusion_val.to_csv("val_df_confusion.csv")
474
+
475
test_df.head()

# Encode the held-out test reviews the same way as the train/validation sets.
encoded_data_test = tokenizer.batch_encode_plus(
    test_df.Review.values,
    add_special_tokens=config.add_special_tokens,
    return_attention_mask=config.return_attention_mask,
    pad_to_max_length=config.pad_to_max_length,
    max_length=config.seq_length,
    return_tensors=config.return_tensors
)
input_ids_test = encoded_data_test['input_ids']
attention_masks_test = encoded_data_test['attention_mask']
labels_test = torch.tensor(test_df.label.values)

# Fresh model instance, then restore the epoch-3 checkpoint into it.
model = BertForSequenceClassification.from_pretrained(config.pretrained_model,
                                                      num_labels=3,
                                                      output_attentions=False,
                                                      output_hidden_states=False)

model.to(config.device)

model.load_state_dict(torch.load('./_BERT_epoch_3.model', map_location=torch.device('cpu')))

# Evaluate on the test set itself; previously the validation loader was
# passed here, so the "test" metrics were validation metrics.
dataset_test = TensorDataset(input_ids_test, attention_masks_test, labels_test)
dataloader_test = DataLoader(dataset_test,
                             sampler=SequentialSampler(dataset_test),
                             batch_size=config.batch_size)

_, predictions_test, true_vals_test = evaluate(dataloader_test)
# accuracy_per_class(predictions_test, true_vals_test, name2label)
500
+
501
def predict_sentiment(text):
    """Predict the sentiment name for a single raw review string.

    The text goes through the same normalisation as the training reviews
    (emoji removal, lowercasing, character whitelist, punctuation removal)
    before it is tokenised; without it, cased input would be tokenised
    differently from what the model was fine-tuned on.

    Args:
        text: the raw review text.

    Returns:
        The value ``label2name`` holds for the arg-max class id, or ``None``
        when the id is not in the mapping.
    """
    # Apply the training-time text normalisation.
    text = remove_punct(clean(remove_emoji(text).lower()))

    # Tokenise and pad to the configured sequence length.
    encoded_text = tokenizer.encode_plus(
        text,
        add_special_tokens=config.add_special_tokens,
        return_attention_mask=config.return_attention_mask,
        pad_to_max_length=config.pad_to_max_length,
        max_length=config.seq_length,
        return_tensors=config.return_tensors
    )

    # Move the input tensors to the configured device.
    input_ids = encoded_text['input_ids'].to(config.device)
    attention_mask = encoded_text['attention_mask'].to(config.device)

    # Put the model in evaluation mode and run it without gradients.
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # No labels were passed, so the first output is the logits.
    logits = outputs[0].detach().cpu().numpy()

    # Pick the class with the highest score.
    pred = np.argmax(logits, axis=1).flatten()[0]

    # Convert the numeric class id to its name.
    return label2name.get(pred)
532
+
533
# Single-sentence smoke test of the prediction helper.
text = "Your text here"
prediction = predict_sentiment(text)
print(f"The sentiment of the text is: {prediction}")

from sklearn.metrics import classification_report

preds_flat_test = np.argmax(predictions_test, axis=1).flatten()
# classification_report expects (y_true, y_pred) in that order.
print(classification_report(true_vals_test, preds_flat_test))

# Predict each test review individually.
pred_final = []

model.eval()
for i, row in tqdm(test_df.iterrows(), total=test_df.shape[0]):
    review = row["Review"]
    encoded_data_test_single = tokenizer.batch_encode_plus(
        [review],
        add_special_tokens=config.add_special_tokens,
        return_attention_mask=config.return_attention_mask,
        pad_to_max_length=config.pad_to_max_length,
        max_length=config.seq_length,
        return_tensors=config.return_tensors
    )
    input_ids_test = encoded_data_test_single['input_ids']
    attention_masks_test = encoded_data_test_single['attention_mask']

    inputs = {'input_ids': input_ids_test.to(device),
              'attention_mask': attention_masks_test.to(device),
              }

    with torch.no_grad():
        outputs = model(**inputs)

    # No labels were passed, so the first output is the logits.
    logits = outputs[0].detach().cpu().numpy()
    pred_final.append(np.argmax(logits, axis=1).flatten()[0])

# Work on a copy so the new columns are not written into a slice.
test_df = test_df.copy()
# add pred into test
test_df["pred"] = pred_final
# Add control column for easier wrong and right predictions
control = test_df.pred.values == test_df.label.values
test_df["control"] = control
# filtering false predictions: keep only the misclassified rows
test_df = test_df[~test_df.control].copy()
test_df["pred_name"] = test_df.pred.apply(lambda x: label2name.get(x))

from sklearn.metrics import confusion_matrix

# We create a confusion matrix to better observe the classes that the model confuses.
# It is built from the misclassified rows only.
pred_name_values = test_df.pred_name.values
label_values = test_df.label_name.values
confmat = confusion_matrix(label_values, pred_name_values, labels=list(name2label.keys()))
confmat

df_confusion_test = pd.crosstab(label_values, pred_name_values)
df_confusion_test

import matplotlib.pyplot as plt
import seaborn as sns

# Heatmap of `confmat`; adjust figsize as needed.
class_names = list(name2label.keys())
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(confmat, annot=True, fmt='d',
            xticklabels=class_names, yticklabels=class_names)
plt.ylabel('Vraies valeurs')
plt.xlabel('Prédictions')
plt.show()