edugp commited on
Commit
a81e575
1 Parent(s): a1f93c9

Add script to generate dataset of embeddings and perplexities. Add script to generate t-SNE plot for embedding and perplexity visualization.

Browse files
Files changed (2) hide show
  1. get_embeddings_and_perplexity.py +47 -0
  2. tsne_plot.py +66 -0
get_embeddings_and_perplexity.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ import kenlm
3
+ from datasets import load_dataset
4
+ from tqdm import tqdm
5
+ import pandas as pd
6
+ import numpy as np
7
+ from sentence_transformers import SentenceTransformer
8
+
9
+
10
+ TOTAL_SENTENCES = 20000
11
+ def pp(log_score, length):
12
+ return 10.0 ** (-log_score / length)
13
+
14
+
15
+ embedder = "distiluse-base-multilingual-cased-v1"
16
+ embedder_model = SentenceTransformer(embedder)
17
+ embedding_shape = embedder_model.encode(["foo"])[0].shape[0]
18
+ # http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
19
+ model = kenlm.Model("es.arpa.bin")
20
+ mc4 = load_dataset("mc4", "es", streaming=True)
21
+ count = 0
22
+ embeddings = []
23
+ lenghts = []
24
+ perplexities = []
25
+ sentences = []
26
+
27
+ for sample in tqdm(mc4["train"].shuffle(buffer_size=100_000), total=416057992):
28
+ lines = sample["text"].split("\n")
29
+ for line in lines:
30
+ count += 1
31
+ log_score = model.score(line)
32
+ length = len(line.split()) + 1
33
+ embedding = embedder_model.encode([line])[0]
34
+ embeddings.append(embedding.tolist())
35
+ perplexities.append(pp(log_score, length))
36
+ lenghts.append(length)
37
+ sentences.append(line)
38
+ if count == TOTAL_SENTENCES:
39
+ break
40
+ if count == TOTAL_SENTENCES:
41
+ embeddings = np.array(embeddings)
42
+ df = pd.DataFrame({"sentence": sentences, "length": lenghts, "perplexity": perplexities})
43
+ for dim in range(embedding_shape):
44
+ df[f"dim_{dim}"] = embeddings[:, dim]
45
+ df.to_csv("mc4-es-perplexity-sentences.tsv", index=None, sep="\t")
46
+ print("DONE!")
47
+ break
tsne_plot.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ from typing import Any, Optional
4
+
5
+ import bokeh
6
+ import numpy as np
7
+ import pandas as pd
8
+ from bokeh.models import ColumnDataSource, HoverTool
9
+ from bokeh.plotting import figure, output_file, save
10
+ from bokeh.transform import factor_cmap
11
+ from bokeh.palettes import Cividis256 as Pallete
12
+ from bokeh.resources import CDN
13
+ from bokeh.embed import file_html
14
+ from sklearn.manifold import TSNE
15
+
16
+
17
+ logging.basicConfig(level = logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+ SEED = 0
20
+
21
+ def get_tsne_embeddings(embeddings: np.ndarray, perplexity: int=30, n_components: int=2, init: str='pca', n_iter: int=5000, random_state: int=SEED) -> np.ndarray:
22
+ tsne = TSNE(perplexity=perplexity, n_components=n_components, init=init, n_iter=n_iter, random_state=random_state)
23
+ return tsne.fit_transform(embeddings)
24
+
25
+ def draw_interactive_scatter_plot(texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray) -> Any:
26
+ # Normalize values to range between 0-255, to assign a color for each value
27
+ max_value = values.max()
28
+ min_value = values.min()
29
+ values_color = ((values - min_value) / (max_value - min_value) * 255).round().astype(int).astype(str)
30
+ values_color_set = sorted(values_color)
31
+
32
+ values_list = values.astype(str).tolist()
33
+ values_set = sorted(values_list)
34
+
35
+ source = ColumnDataSource(data=dict(x=xs, y=ys, text=texts, perplexity=values_list))
36
+ hover = HoverTool(tooltips=[('Sentence', '@text{safe}'), ('Perplexity', '@perplexity')])
37
+ p = figure(plot_width=1200, plot_height=1200, tools=[hover], title='Sentences')
38
+ p.circle(
39
+ 'x', 'y', size=10, source=source, fill_color=factor_cmap('perplexity', palette=[Pallete[int(id_)] for id_ in values_color_set], factors=values_set))
40
+ return p
41
+
42
+ def generate_plot(tsv: str, output_file_name: str, sample: Optional[int]):
43
+ logger.info("Loading dataset in memory")
44
+ df = pd.read_csv(tsv, sep="\t")
45
+ if sample:
46
+ df = df.sample(sample, random_state=SEED)
47
+ logger.info(f"Dataset contains {df.shape[0]} sentences")
48
+ embeddings = df[sorted([col for col in df.columns if col.startswith("dim")], key=lambda x: int(x.split("_")[-1]))].values
49
+ logger.info(f"Running t-SNE")
50
+ tsne_embeddings = get_tsne_embeddings(embeddings)
51
+ logger.info(f"Generating figure")
52
+ plot = draw_interactive_scatter_plot(df["sentence"].values, tsne_embeddings[:, 0], tsne_embeddings[:, 1], df["perplexity"].values)
53
+ output_file(output_file_name)
54
+ save(plot)
55
+
56
+
57
+
58
+
59
+ if __name__ == "__main__":
60
+ parser = argparse.ArgumentParser(description="Embeddings t-SNE plot")
61
+ parser.add_argument("--tsv", type=str, help="Path to tsv file with columns 'text', 'perplexity' and N 'dim_<i> columns for each embdeding dimension.'")
62
+ parser.add_argument("--output_file", type=str, help="Path to the output HTML file for the interactive plot.", default="perplexity_colored_embeddings.html")
63
+ parser.add_argument("--sample", type=int, help="Number of sentences to use", default=None)
64
+
65
+ args = parser.parse_args()
66
+ generate_plot(args.tsv, args.output_file, args.sample)