# edia_full_en / modules / module_WordExplorer.py
# Last commit e8aad19 ("First commit") by nanom; file size 6.82 kB.
from typing import Dict, List, Optional, Tuple

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from numpy.linalg import norm

mpl.use('Agg')
class WordToPlot:
    """Plain record describing one word to draw on the projection plot.

    Attributes:
        word: text of the point, also used as its annotation label.
        color: color of the annotation text.
        bias_space: integer id of the group (word list) the word belongs to.
        alpha: opacity of the annotation text.
    """

    def __init__(
        self,
        word: str,
        color: str,
        bias_space: int,
        alpha: float
    ) -> None:
        # Stored exactly as given: no validation or conversion is performed.
        self.word, self.color, self.bias_space, self.alpha = (
            word, color, bias_space, alpha
        )
class WordExplorer:
    """Check, compare and plot words of an embedding.

    ``embedding`` must support ``word in embedding`` and expose the
    ``getNearestNeighbors``, ``getPCA`` and ``getEmbedding`` methods called
    below. ``errorManager`` must expose ``process``, which turns an error
    key list into the message that is returned or raised.
    """

    def __init__(
        self,
        embedding,  # Embedding class instance
        errorManager  # ErrorManager class instance
    ) -> None:
        self.embedding = embedding
        self.errorManager = errorManager

    def __errorChecking(
        self,
        word: str
    ) -> str:
        """Return ``errorManager.process`` applied to the error keys for ``word``.

        An empty ``word`` yields ``['EMBEDDING_NO_WORD_PROVIDED']``. A word
        missing from the embedding yields ``['EMBEDDING_WORD_OOV', word]``.
        Otherwise ``""`` is passed. ``check_oov`` treats a falsy result as
        "no error".
        """
        out_msj = ""
        if not word:
            out_msj = ['EMBEDDING_NO_WORD_PROVIDED']
        elif word not in self.embedding:
            out_msj = ['EMBEDDING_WORD_OOV', word]
        return self.errorManager.process(out_msj)

    def check_oov(
        self,
        wordlists: List[List[str]]
    ) -> Optional[str]:
        """Return the first error message found in ``wordlists``, else None.

        Words are checked in order and the check stops at the first problem.
        """
        for wordlist in wordlists:
            for word in wordlist:
                msg = self.__errorChecking(word)
                if msg:
                    return msg
        return None

    def get_neighbors(
        self,
        word: str,
        n_neighbors: int,
        nn_method: str
    ) -> List[str]:
        """Return the nearest neighbors of ``word`` from the embedding.

        Raises:
            Exception: with the error message if ``word`` is empty or OOV.
        """
        err = self.check_oov([[word]])
        if err:
            raise Exception(err)
        return self.embedding.getNearestNeighbors(word, n_neighbors, nn_method)

    def get_df(
        self,
        words_embedded: np.ndarray,
        processed_word_list: List['WordToPlot']
    ) -> pd.DataFrame:
        """Build the plotting DataFrame.

        One row per item of ``processed_word_list``. The numbered coordinate
        columns come from ``words_embedded``, followed by the columns 'word',
        'color', 'alpha' and 'word_bias_space' copied from each item.
        """
        df = pd.DataFrame(words_embedded)
        df['word'] = [wtp.word for wtp in processed_word_list]
        df['color'] = [wtp.color for wtp in processed_word_list]
        df['alpha'] = [wtp.alpha for wtp in processed_word_list]
        df['word_bias_space'] = [wtp.bias_space for wtp in processed_word_list]
        return df

    def get_plot(
        self,
        data: pd.DataFrame,
        processed_word_list: List['WordToPlot'],
        words_embedded: np.ndarray,
        color_dict: Dict,
        n_neighbors: int,
        n_alpha: float,
        fontsize: int = 18,
        figsize: Tuple[int, int] = (20, 15)
    ):
        """Scatter-plot ``data`` on columns 0/1 and label every point.

        Rows with ``alpha == 1`` are drawn opaque, with a legend keyed by
        'word_bias_space'. When ``n_neighbors > 0``, the other rows are drawn
        with opacity ``n_alpha`` and no legend. Each item of
        ``processed_word_list`` is annotated with its word at row ``i`` of
        ``words_embedded``.

        Returns:
            The matplotlib Figure.
        """
        fig, ax = plt.subplots(figsize=figsize)
        sns.scatterplot(
            data=data[data['alpha'] == 1],
            x=0,
            y=1,
            style='word_bias_space',
            hue='word_bias_space',
            ax=ax,
            palette=color_dict
        )
        if n_neighbors > 0:
            # NOTE(review): this layer uses style='color' while the layer above
            # uses style='word_bias_space', so marker shapes may differ between
            # a word and its neighbors; confirm this is intended.
            sns.scatterplot(
                data=data[data['alpha'] != 1],
                x=0,
                y=1,
                style='color',
                hue='word_bias_space',
                ax=ax,
                alpha=n_alpha,
                legend=False,
                palette=color_dict
            )
        for i, wtp in enumerate(processed_word_list):
            x, y = words_embedded[i, :]
            ax.annotate(
                wtp.word,
                xy=(x, y),
                xytext=(5, 2),
                color=wtp.color,
                textcoords='offset points',
                ha='right',
                va='bottom',
                size=fontsize,
                alpha=wtp.alpha
            )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel('')
        ax.set_ylabel('')
        fig.tight_layout()
        return fig

    def plot_projections_2d(
        self,
        wordlist_0: List[str],
        wordlist_1: Optional[List[str]] = None,
        wordlist_2: Optional[List[str]] = None,
        wordlist_3: Optional[List[str]] = None,
        wordlist_4: Optional[List[str]] = None,
        **kwargs
    ):
        """Plot the 2-D PCA projection of up to five word lists.

        The index of each list (0-4) is its bias-space id and the key of its
        color. Optional kwargs:
        ``color_wordlist_0`` … ``color_wordlist_4`` (colors),
        ``n_neighbors`` (default 0), ``n_alpha`` (default 0.3),
        ``nn_method`` (default 'sklearn'), ``fontsize`` (default 18) and
        ``figsize`` (default (20, 15)).

        Raises:
            Exception: if any word is empty or OOV, or if every list is empty.

        Returns:
            The matplotlib Figure built by ``get_plot``.
        """
        # None defaults replace the former shared mutable [] defaults.
        wordlist_choice = [wordlist_0] + [
            [] if wordlist is None else wordlist
            for wordlist in (wordlist_1, wordlist_2, wordlist_3, wordlist_4)
        ]
        err = self.check_oov(wordlist_choice)
        if err:
            raise Exception(err)

        color_dict = {
            0: kwargs.get('color_wordlist_0', '#000000'),
            1: kwargs.get('color_wordlist_1', '#1f78b4'),
            2: kwargs.get('color_wordlist_2', '#33a02c'),
            3: kwargs.get('color_wordlist_3', '#e31a1c'),
            4: kwargs.get('color_wordlist_4', '#6a3d9a')
        }
        n_neighbors = kwargs.get('n_neighbors', 0)
        n_alpha = kwargs.get('n_alpha', 0.3)
        nn_method = kwargs.get('nn_method', 'sklearn')

        processed_word_list = []
        # Set of words already queued: O(1) membership instead of rebuilding
        # a list of all words for every neighbor (was O(n^2)).
        seen_words = set()
        for bias_space, word_list_to_process in enumerate(wordlist_choice):
            color = color_dict[bias_space]
            for word in word_list_to_process:
                # Input words are always added (opaque), even if seen before.
                processed_word_list.append(
                    WordToPlot(word, color, bias_space, 1)
                )
                seen_words.add(word)
                if n_neighbors > 0:
                    neighbors = self.get_neighbors(
                        word,
                        n_neighbors=n_neighbors,
                        nn_method=nn_method
                    )
                    for n in neighbors:
                        # Neighbors are added only once, semi-transparent.
                        if n not in seen_words:
                            processed_word_list.append(
                                WordToPlot(n, color, bias_space, n_alpha)
                            )
                            seen_words.add(n)

        if not processed_word_list:
            raise Exception('Only empty lists were passed')

        words_embedded = np.array(
            [self.embedding.getPCA(wtp.word) for wtp in processed_word_list]
        )
        data = self.get_df(
            words_embedded,
            processed_word_list
        )
        fig = self.get_plot(
            data,
            processed_word_list,
            words_embedded,
            color_dict,
            n_neighbors,
            n_alpha,
            kwargs.get('fontsize', 18),
            kwargs.get('figsize', (20, 15))
        )
        # NOTE(review): the figure stays registered with pyplot (never closed);
        # confirm the caller closes it to avoid accumulating figures.
        plt.show()
        return fig

    # TODO: this method has no callers. Remove it?
    def doesnt_match(
        self,
        wordlist: List[str]
    ) -> str:
        """Return the word of ``wordlist`` least cosine-similar to the list mean.

        On ties the later word wins. An empty list returns ``""``.

        Raises:
            Exception: if any word is empty or OOV.
        """
        err = self.check_oov([wordlist])
        if err:
            raise Exception(err)
        if not wordlist:
            # Avoid np.mean over an empty slice (RuntimeWarning + NaN).
            return ""

        # Fetch each embedding once and reuse the rows below.
        words_emb = np.array(
            [self.embedding.getEmbedding(word) for word in wordlist]
        )
        mean_vec = np.mean(words_emb, axis=0)
        mean_norm = norm(mean_vec)

        doesnt_match = ""
        # Start at +inf rather than 1.0: rounding can push a cosine
        # similarity slightly above 1.0 (e.g. a single-word list), which
        # previously made every comparison fail and returned "".
        farthest_emb = np.inf
        for word, word_emb in zip(wordlist, words_emb):
            cos_sim = np.dot(mean_vec, word_emb) / (mean_norm * norm(word_emb))
            if cos_sim <= farthest_emb:
                farthest_emb = cos_sim
                doesnt_match = word
        return doesnt_match