nanom commited on
Commit
a101a53
Β·
1 Parent(s): 8e9956d

Improvement in the display of the graph axes labels. Generalization of rankSent class. Minor fixes.

Browse files
modules/module_BiasExplorer.py CHANGED
@@ -5,7 +5,7 @@ import seaborn as sns
5
  import matplotlib.pyplot as plt
6
  from sklearn.decomposition import PCA
7
  from typing import List, Dict, Tuple, Optional, Any
8
- from modules.utils import normalize, cosine_similarity, project_params, take_two_sides_extreme_sorted
9
 
10
  __all__ = ['WordBiasExplorer', 'WEBiasExplorer2Spaces', 'WEBiasExplorer4Spaces']
11
 
@@ -371,9 +371,14 @@ class WEBiasExplorer2Spaces(WordBiasExplorer):
371
  plt.xticks(np.arange(-most_extream_projection,
372
  most_extream_projection + axis_projection_step,
373
  axis_projection_step))
374
- xlabel = ('← {} {} {} β†’'.format(self.negative_end,
375
- ' ' * 20,
376
- self.positive_end))
 
 
 
 
 
377
 
378
  plt.xlabel(xlabel)
379
  plt.ylabel('Words')
@@ -515,13 +520,20 @@ class WEBiasExplorer4Spaces(WordBiasExplorer):
515
  for _, row in (projections_df.iterrows()):
516
  ax.annotate(
517
  row['word'], (row['projection_x'], row['projection_y']))
518
- x_label = '← {} {} {} β†’'.format(name_left,
519
- ' ' * 20,
520
- name_right)
521
 
522
- y_label = '← {} {} {} β†’'.format(name_top,
523
- ' ' * 20,
524
- name_bottom)
 
 
 
 
 
 
 
 
 
 
525
 
526
  plt.xlabel(x_label)
527
  ax.xaxis.set_label_position('bottom')
 
5
  import matplotlib.pyplot as plt
6
  from sklearn.decomposition import PCA
7
  from typing import List, Dict, Tuple, Optional, Any
8
+ from modules.utils import normalize, cosine_similarity, project_params, take_two_sides_extreme_sorted, axes_labels_format
9
 
10
  __all__ = ['WordBiasExplorer', 'WEBiasExplorer2Spaces', 'WEBiasExplorer4Spaces']
11
 
 
371
  plt.xticks(np.arange(-most_extream_projection,
372
  most_extream_projection + axis_projection_step,
373
  axis_projection_step))
374
+
375
+
376
+ xlabel = axes_labels_format(
377
+ left=self.negative_end,
378
+ right=self.positive_end,
379
+ sep=' ' * 20,
380
+ word_wrap=3
381
+ )
382
 
383
  plt.xlabel(xlabel)
384
  plt.ylabel('Words')
 
520
  for _, row in (projections_df.iterrows()):
521
  ax.annotate(
522
  row['word'], (row['projection_x'], row['projection_y']))
 
 
 
523
 
524
+
525
+ x_label = axes_labels_format(
526
+ left=name_left,
527
+ right=name_right,
528
+ sep=' ' * 20,
529
+ word_wrap=3
530
+ )
531
+ y_label = axes_labels_format(
532
+ left=name_top,
533
+ right=name_bottom,
534
+ sep=' ' * 20,
535
+ word_wrap=3
536
+ )
537
 
538
  plt.xlabel(x_label)
539
  ax.xaxis.set_label_position('bottom')
modules/module_connection.py CHANGED
@@ -422,11 +422,12 @@ class PhraseBiasExplorerConnector(Connector):
422
  def rank_sentence_options(
423
  self,
424
  sent: str,
425
- word_list: str,
426
  banned_word_list: str,
427
- useArticles: bool,
428
- usePrepositions: bool,
429
- useConjunctions: bool
 
430
  ) -> Tuple:
431
 
432
  sent = " ".join(sent.strip().replace("*"," * ").split())
@@ -435,7 +436,7 @@ class PhraseBiasExplorerConnector(Connector):
435
  if err:
436
  return err, "", ""
437
 
438
- word_list = self.parse_words(word_list)
439
  banned_word_list = self.parse_words(banned_word_list)
440
 
441
  # Save inputs in logs file
@@ -443,16 +444,17 @@ class PhraseBiasExplorerConnector(Connector):
443
  self.logs_file_name,
444
  self.headers,
445
  sent,
446
- word_list
447
  )
448
 
449
  all_plls_scores = self.phrase_bias_explorer.rank(
450
  sent,
451
- word_list,
452
  banned_word_list,
453
- useArticles,
454
- usePrepositions,
455
- useConjunctions
 
456
  )
457
 
458
  all_plls_scores = self.phrase_bias_explorer.Label.compute(all_plls_scores)
 
422
  def rank_sentence_options(
423
  self,
424
  sent: str,
425
+ interest_word_list: str,
426
  banned_word_list: str,
427
+ exclude_articles: bool,
428
+ exclude_prepositions: bool,
429
+ exclude_conjunctions: bool,
430
+ n_predictions: int=5
431
  ) -> Tuple:
432
 
433
  sent = " ".join(sent.strip().replace("*"," * ").split())
 
436
  if err:
437
  return err, "", ""
438
 
439
+ interest_word_list = self.parse_words(interest_word_list)
440
  banned_word_list = self.parse_words(banned_word_list)
441
 
442
  # Save inputs in logs file
 
444
  self.logs_file_name,
445
  self.headers,
446
  sent,
447
+ interest_word_list
448
  )
449
 
450
  all_plls_scores = self.phrase_bias_explorer.rank(
451
  sent,
452
+ interest_word_list,
453
  banned_word_list,
454
+ exclude_articles,
455
+ exclude_prepositions,
456
+ exclude_conjunctions,
457
+ n_predictions
458
  )
459
 
460
  all_plls_scores = self.phrase_bias_explorer.Label.compute(all_plls_scores)
modules/module_rankSents.py CHANGED
@@ -66,13 +66,14 @@ class RankSents:
66
 
67
  return self.errorManager.process(out_msj)
68
 
69
- def getTop5Predictions(
70
  self,
 
71
  sent: str,
72
- banned_wl: List[str],
73
- articles: bool,
74
- prepositions: bool,
75
- conjunctions: bool
76
  ) -> List[str]:
77
 
78
  sent_masked = sent.replace("*", self.tokenizer.mask_token)
@@ -80,7 +81,8 @@ class RankSents:
80
  sent_masked,
81
  add_special_tokens=True,
82
  return_tensors='pt',
83
- return_attention_mask=True, truncation=True
 
84
  )
85
 
86
  tk_position_mask = torch.where(inputs['input_ids'][0] == self.tokenizer.mask_token_id)[0].item()
@@ -94,26 +96,26 @@ class RankSents:
94
  probabilities = outputs[tk_position_mask]
95
  first_tk_id = torch.argsort(probabilities, descending=True)
96
 
97
- top5_tks_pred = []
98
  for tk_id in first_tk_id:
99
  tk_string = self.tokenizer.decode([tk_id])
100
 
101
- tk_is_banned = tk_string in banned_wl
102
  tk_is_punctuation = not tk_string.isalnum()
103
  tk_is_substring = tk_string.startswith("##")
104
  tk_is_special = (tk_string in self.tokenizer.all_special_tokens)
105
 
106
- if articles:
107
  tk_is_article = tk_string in self.articles
108
  else:
109
  tk_is_article = False
110
 
111
- if prepositions:
112
  tk_is_prepositions = tk_string in self.prepositions
113
  else:
114
  tk_is_prepositions = False
115
 
116
- if conjunctions:
117
  tk_is_conjunctions = tk_string in self.conjunctions
118
  else:
119
  tk_is_conjunctions = False
@@ -128,39 +130,41 @@ class RankSents:
128
  tk_is_conjunctions
129
  ])
130
 
131
- if predictions_is_dessire and len(top5_tks_pred) < 5:
132
- top5_tks_pred.append(tk_string)
133
 
134
- elif len(top5_tks_pred) >= 5:
135
  break
136
 
137
- return top5_tks_pred
138
 
139
  def rank(self,
140
  sent: str,
141
- word_list: List[str]=[],
142
  banned_word_list: List[str]=[],
143
- articles: bool=False,
144
- prepositions: bool=False,
145
- conjunctions: bool=False
 
146
  ) -> Dict[str, float]:
147
 
148
  err = self.errorChecking(sent)
149
  if err:
150
  raise Exception(err)
151
 
152
- if not word_list:
153
- word_list = self.getTop5Predictions(
 
154
  sent,
155
  banned_word_list,
156
- articles,
157
- prepositions,
158
- conjunctions
159
  )
160
 
161
  sent_list = []
162
  sent_list2print = []
163
- for word in word_list:
164
  sent_list.append(sent.replace("*", "<"+word+">"))
165
  sent_list2print.append(sent.replace("*", "<"+word+">"))
166
 
 
66
 
67
  return self.errorManager.process(out_msj)
68
 
69
+ def getTopPredictions(
70
  self,
71
+ n: int,
72
  sent: str,
73
+ banned_word_list: List[str],
74
+ exclude_articles: bool,
75
+ exclude_prepositions: bool,
76
+ exclude_conjunctions: bool,
77
  ) -> List[str]:
78
 
79
  sent_masked = sent.replace("*", self.tokenizer.mask_token)
 
81
  sent_masked,
82
  add_special_tokens=True,
83
  return_tensors='pt',
84
+ return_attention_mask=True,
85
+ truncation=True
86
  )
87
 
88
  tk_position_mask = torch.where(inputs['input_ids'][0] == self.tokenizer.mask_token_id)[0].item()
 
96
  probabilities = outputs[tk_position_mask]
97
  first_tk_id = torch.argsort(probabilities, descending=True)
98
 
99
+ top_tks_pred = []
100
  for tk_id in first_tk_id:
101
  tk_string = self.tokenizer.decode([tk_id])
102
 
103
+ tk_is_banned = tk_string in banned_word_list
104
  tk_is_punctuation = not tk_string.isalnum()
105
  tk_is_substring = tk_string.startswith("##")
106
  tk_is_special = (tk_string in self.tokenizer.all_special_tokens)
107
 
108
+ if exclude_articles:
109
  tk_is_article = tk_string in self.articles
110
  else:
111
  tk_is_article = False
112
 
113
+ if exclude_prepositions:
114
  tk_is_prepositions = tk_string in self.prepositions
115
  else:
116
  tk_is_prepositions = False
117
 
118
+ if exclude_conjunctions:
119
  tk_is_conjunctions = tk_string in self.conjunctions
120
  else:
121
  tk_is_conjunctions = False
 
130
  tk_is_conjunctions
131
  ])
132
 
133
+ if predictions_is_dessire and len(top_tks_pred) < n:
134
+ top_tks_pred.append(tk_string)
135
 
136
+ elif len(top_tks_pred) >= n:
137
  break
138
 
139
+ return top_tks_pred
140
 
141
  def rank(self,
142
  sent: str,
143
+ interest_word_list: List[str]=[],
144
  banned_word_list: List[str]=[],
145
+ exclude_articles: bool=False,
146
+ exclude_prepositions: bool=False,
147
+ exclude_conjunctions: bool=False,
148
+ n_predictions: int=5
149
  ) -> Dict[str, float]:
150
 
151
  err = self.errorChecking(sent)
152
  if err:
153
  raise Exception(err)
154
 
155
+ if not interest_word_list:
156
+ interest_word_list = self.getTopPredictions(
157
+ n_predictions,
158
  sent,
159
  banned_word_list,
160
+ exclude_articles,
161
+ exclude_prepositions,
162
+ exclude_conjunctions
163
  )
164
 
165
  sent_list = []
166
  sent_list2print = []
167
+ for word in interest_word_list:
168
  sent_list.append(sent.replace("*", "<"+word+">"))
169
  sent_list2print.append(sent.replace("*", "<"+word+">"))
170
 
modules/utils.py CHANGED
@@ -1,13 +1,15 @@
1
  import numpy as np
2
  import pandas as pd
3
- from datetime import datetime
4
  import pytz
 
 
 
5
 
6
 
7
  class DateLogs:
8
  def __init__(
9
  self,
10
- zone: str="America/Argentina/Cordoba"
11
  ) -> None:
12
 
13
  self.time_zone = pytz.timezone(zone)
@@ -80,4 +82,63 @@ def cosine_similarity(
80
  v_norm = np.linalg.norm(v)
81
  u_norm = np.linalg.norm(u)
82
  similarity = v @ u / (v_norm * u_norm)
83
- return similarity
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
  import pandas as pd
 
3
  import pytz
4
+ from datetime import datetime
5
+ from typing import List
6
+
7
 
8
 
9
  class DateLogs:
10
  def __init__(
11
  self,
12
+ zone: str = "America/Argentina/Cordoba"
13
  ) -> None:
14
 
15
  self.time_zone = pytz.timezone(zone)
 
82
  v_norm = np.linalg.norm(v)
83
  u_norm = np.linalg.norm(u)
84
  similarity = v @ u / (v_norm * u_norm)
85
+ return similarity
86
+
87
+
88
+ def axes_labels_format(
89
+ left: str,
90
+ right: str,
91
+ sep: str,
92
+ word_wrap: int = 4
93
+ ) -> str:
94
+
95
+ def sparse(
96
+ word: str,
97
+ max_len: int
98
+ ) -> str:
99
+
100
+ diff = max_len-len(word)
101
+ rest = diff if diff > 0 else 0
102
+ return word+" "*rest
103
+
104
+ def gen_block(
105
+ list_: List[str],
106
+ n_rows:int,
107
+ n_cols:int
108
+ ) -> List[str]:
109
+
110
+ block = []
111
+ block_row = []
112
+ for r in range(n_rows):
113
+ for c in range(n_cols):
114
+ i = r * n_cols + c
115
+ w = list_[i] if i <= len(list_) - 1 else ""
116
+ block_row.append(w)
117
+ if (i+1) % n_cols == 0:
118
+ block.append(block_row)
119
+ block_row = []
120
+ return block
121
+
122
+ # Transform 'string' to list of string
123
+ l_list = [word.strip() for word in left.split(",") if word.strip() != ""]
124
+ r_list = [word.strip() for word in right.split(",") if word.strip() != ""]
125
+
126
+ # Get longest word, and longest_list
127
+ longest_list = max(len(l_list), len(r_list))
128
+ longest_word = len(max( max(l_list, key=len), max(r_list, key=len)))
129
+
130
+ # Creation of word blocks for each list
131
+ n_rows = (longest_list // word_wrap) if longest_list % word_wrap == 0 else (longest_list // word_wrap) + 1
132
+ n_cols = word_wrap
133
+
134
+ l_block = gen_block(l_list, n_rows, n_cols)
135
+ r_block = gen_block(r_list, n_rows, n_cols)
136
+
137
+ # Transform list of list to sparse string
138
+ labels = ""
139
+ for i,(l,r) in enumerate(zip(l_block, r_block)):
140
+ line = ' '.join([sparse(w, longest_word) for w in l]) + sep + \
141
+ ' '.join([sparse(w, longest_word) for w in r])
142
+ labels += f"← {line} β†’\n" if i==0 else f" {line} \n"
143
+
144
+ return labels