versae committed on
Commit
8b9ba87
2 Parent(s): 853cd83 d5cede4

Adjust batch size for extracting tokens

flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:50c50c05859f43aa6a08aa3106a1ca62d225f1ac927d57e0e86e422cff5ee7a7
- size 711588089
+ oid sha256:5ff31ebb2460dbc41a160cc755d0555bb8c84672563808b968a2a121c1b2414a
+ size 711587941
get_embeddings_and_perplexity.py ADDED
@@ -0,0 +1,47 @@
+ #!/usr/bin/env python
+ import kenlm
+ from datasets import load_dataset
+ from tqdm import tqdm
+ import pandas as pd
+ import numpy as np
+ from sentence_transformers import SentenceTransformer
+
+
+ TOTAL_SENTENCES = 20000
+ def pp(log_score, length):
+     return 10.0 ** (-log_score / length)
+
+
+ embedder = "distiluse-base-multilingual-cased-v1"
+ embedder_model = SentenceTransformer(embedder)
+ embedding_shape = embedder_model.encode(["foo"])[0].shape[0]
+ # http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
+ model = kenlm.Model("es.arpa.bin")
+ mc4 = load_dataset("mc4", "es", streaming=True)
+ count = 0
+ embeddings = []
+ lengths = []
+ perplexities = []
+ sentences = []
+
+ for sample in tqdm(mc4["train"].shuffle(buffer_size=100_000), total=416057992):
+     lines = sample["text"].split("\n")
+     for line in lines:
+         count += 1
+         log_score = model.score(line)
+         length = len(line.split()) + 1
+         embedding = embedder_model.encode([line])[0]
+         embeddings.append(embedding.tolist())
+         perplexities.append(pp(log_score, length))
+         lengths.append(length)
+         sentences.append(line)
+         if count == TOTAL_SENTENCES:
+             break
+     if count == TOTAL_SENTENCES:
+         embeddings = np.array(embeddings)
+         df = pd.DataFrame({"sentence": sentences, "length": lengths, "perplexity": perplexities})
+         for dim in range(embedding_shape):
+             df[f"dim_{dim}"] = embeddings[:, dim]
+         df.to_csv("mc4-es-perplexity-sentences.tsv", index=None, sep="\t")
+         print("DONE!")
+         break
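The pp helper above converts KenLM's base-10 log-probability into a per-word perplexity. A minimal standalone sketch of that conversion, assuming the same es.arpa.bin model file linked in the script (the example sentence is made up):

# Standalone sketch of the perplexity computation used above.
# Assumes es.arpa.bin (the cc_net Spanish KenLM model linked in the script)
# has been downloaded next to this file; the example sentence is arbitrary.
import kenlm

model = kenlm.Model("es.arpa.bin")
line = "Este es un ejemplo de frase en español."
log_score = model.score(line)       # base-10 log-probability of the sentence
length = len(line.split()) + 1      # word count plus the end-of-sentence token
perplexity = 10.0 ** (-log_score / length)
print(f"Perplexity: {perplexity:.1f}")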
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b4265b625a915f8a622926c9be27d6b1f3f2bc44481f81ab5d53eace54a0bc06
+ size 1421780139
tokens.py.orig ADDED
@@ -0,0 +1,31 @@
+ #!/usr/bin/env python3
+ from datasets import load_dataset
+ from tokenizers import ByteLevelBPETokenizer
+
+ # Load dataset
+ <<<<<<< HEAD
+ dataset = load_dataset("oscar", "unshuffled_deduplicated_es", split="train[:5000000]")
+
+ # Instantiate tokenizer
+ tokenizer = ByteLevelBPETokenizer()
+ def batch_iterator(batch_size=100_000):
+ =======
+ dataset = load_dataset("oscar", "unshuffled_deduplicated_es", split="train")
+
+ # Instantiate tokenizer
+ tokenizer = ByteLevelBPETokenizer()
+ def batch_iterator(batch_size=1_000_000):
+ >>>>>>> d5cede47e74aa6ec36f20acf5aba37c6734c6186
+     for i in range(0, len(dataset), batch_size):
+         yield dataset["text"][i: i + batch_size]
+
+ # Customized training
+ tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
+     "<s>",
+     "<pad>",
+     "</s>",
+     "<unk>",
+     "<mask>",
+ ])
+ # Save files to disk
+ tokenizer.save("./tokenizer.json")
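tokens.py.orig preserves the merge conflict between the two parents rather than a resolved file; the resolution actually committed is not shown in this diff. Given the commit message ("Adjust batch size for extracting tokens"), a plausible resolution keeps the full OSCAR split from d5cede4 together with its larger batch size. A sketch, assuming that choice (the real resolved tokens.py may differ):

#!/usr/bin/env python3
# Hypothetical resolution of the conflict above: full OSCAR split and the
# 1_000_000 batch size from parent d5cede4. Not the committed file.
from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer

# Load dataset
dataset = load_dataset("oscar", "unshuffled_deduplicated_es", split="train")

# Instantiate tokenizer
tokenizer = ByteLevelBPETokenizer()
def batch_iterator(batch_size=1_000_000):
    for i in range(0, len(dataset), batch_size):
        yield dataset["text"][i: i + batch_size]

# Customized training
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])
# Save files to disk
tokenizer.save("./tokenizer.json")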
tsne_plot.py ADDED
@@ -0,0 +1,64 @@
+ import argparse
+ import logging
+ from typing import Any, Optional
+
+ import bokeh
+ import numpy as np
+ import pandas as pd
+ from bokeh.models import ColumnDataSource, HoverTool
+ from bokeh.plotting import figure, output_file, save
+ from bokeh.transform import factor_cmap
+ from bokeh.palettes import Cividis256 as Pallete
+ from sklearn.manifold import TSNE
+
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+ SEED = 0
+
+ def get_tsne_embeddings(embeddings: np.ndarray, perplexity: int = 30, n_components: int = 2, init: str = 'pca', n_iter: int = 5000, random_state: int = SEED) -> np.ndarray:
+     tsne = TSNE(perplexity=perplexity, n_components=n_components, init=init, n_iter=n_iter, random_state=random_state)
+     return tsne.fit_transform(embeddings)
+
+ def draw_interactive_scatter_plot(texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray) -> Any:
+     # Normalize values to the 0-255 range to assign a color to each value
+     max_value = values.max()
+     min_value = values.min()
+     values_color = ((values - min_value) / (max_value - min_value) * 255).round().astype(int).astype(str)
+     values_color_set = sorted(values_color)
+
+     values_list = values.astype(str).tolist()
+     values_set = sorted(values_list)
+
+     source = ColumnDataSource(data=dict(x=xs, y=ys, text=texts, perplexity=values_list))
+     hover = HoverTool(tooltips=[('Sentence', '@text{safe}'), ('Perplexity', '@perplexity')])
+     p = figure(plot_width=1200, plot_height=1200, tools=[hover], title='Sentences')
+     p.circle(
+         'x', 'y', size=10, source=source, fill_color=factor_cmap('perplexity', palette=[Pallete[int(id_)] for id_ in values_color_set], factors=values_set))
+     return p
+
+ def generate_plot(tsv: str, output_file_name: str, sample: Optional[int]):
+     logger.info("Loading dataset in memory")
+     df = pd.read_csv(tsv, sep="\t")
+     if sample:
+         df = df.sample(sample, random_state=SEED)
+     logger.info(f"Dataset contains {df.shape[0]} sentences")
+     embeddings = df[sorted([col for col in df.columns if col.startswith("dim")], key=lambda x: int(x.split("_")[-1]))].values
+     logger.info("Running t-SNE")
+     tsne_embeddings = get_tsne_embeddings(embeddings)
+     logger.info("Generating figure")
+     plot = draw_interactive_scatter_plot(df["sentence"].values, tsne_embeddings[:, 0], tsne_embeddings[:, 1], df["perplexity"].values)
+     output_file(output_file_name)
+     save(plot)
+
+
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Embeddings t-SNE plot")
+     parser.add_argument("--tsv", type=str, help="Path to a TSV file with columns 'sentence', 'perplexity' and N 'dim_<i>' columns, one per embedding dimension.")
+     parser.add_argument("--output_file", type=str, help="Path to the output HTML file for the interactive plot.", default="perplexity_colored_embeddings.html")
+     parser.add_argument("--sample", type=int, help="Number of sentences to use", default=None)
+
+     args = parser.parse_args()
+     generate_plot(args.tsv, args.output_file, args.sample)
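A usage sketch for the plotting script above, assuming the TSV written by get_embeddings_and_perplexity.py; the 5000-sentence sample is an arbitrary example value (the equivalent CLI call would be python tsne_plot.py --tsv mc4-es-perplexity-sentences.tsv --sample 5000):

# Usage sketch: call generate_plot from the module above with the TSV
# produced by get_embeddings_and_perplexity.py. The sample size is an
# arbitrary example, not a value taken from the repository.
from tsne_plot import generate_plot

generate_plot(
    tsv="mc4-es-perplexity-sentences.tsv",
    output_file_name="perplexity_colored_embeddings.html",
    sample=5000,
)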