File size: 4,549 Bytes
86e673e
 
 
 
7b62017
86e673e
 
 
3c30fa3
 
 
 
86e673e
 
3c30fa3
 
 
 
 
 
 
 
 
 
 
 
 
 
86e673e
3c30fa3
86e673e
 
 
 
 
 
 
 
 
 
7b62017
 
 
 
 
 
 
 
 
 
86e673e
7b62017
 
 
 
 
 
 
86e673e
3c30fa3
 
 
 
86e673e
 
 
 
7b62017
 
 
 
 
3c30fa3
7b62017
86e673e
 
 
 
 
 
 
7b62017
 
 
86e673e
 
3c30fa3
 
 
 
 
 
 
 
86e673e
 
7b62017
86e673e
 
 
 
7b62017
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c30fa3
7b62017
 
 
 
 
 
 
3c30fa3
7b62017
3c30fa3
 
7b62017
3c30fa3
 
 
 
 
7b62017
86e673e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import logging
from functools import partial
from typing import Optional

import pandas as pd
import typer
from bokeh.plotting import output_file as bokeh_output_file
from bokeh.plotting import save
from embedding_lenses.dimensionality_reduction import (
    get_tsne_embeddings,
    get_umap_embeddings,
)
from embedding_lenses.embedding import load_model

from perplexity_lenses import REGISTRY_DATASET
from perplexity_lenses.data import (
    documents_df_to_sentences_df,
    hub_dataset_to_dataframe,
)
from perplexity_lenses.engine import (
    DIMENSIONALITY_REDUCTION_ALGORITHMS,
    DOCUMENT_TYPES,
    EMBEDDING_MODELS,
    LANGUAGES,
    PERPLEXITY_MODELS,
    SEED,
    generate_plot,
)
from perplexity_lenses.perplexity import KenlmModel
from perplexity_lenses.visualization import draw_histogram

# Configure root logging at INFO so the progress messages emitted below are shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Typer application object; commands are registered on it with @app.command().
app = typer.Typer()


@app.command()
def main(
    dataset: str = typer.Option(
        "mc4", help="The name of the hub dataset or local csv/tsv file."
    ),
    dataset_config: Optional[str] = typer.Option(
        "es",
        help="The configuration of the hub dataset, if any. Does not apply to local csv/tsv files.",
    ),
    dataset_split: Optional[str] = typer.Option(
        "train", help="The dataset split. Does not apply to local csv/tsv files."
    ),
    text_column: str = typer.Option("text", help="The text field name."),
    language: str = typer.Option(
        "es", help=f"The language of the text. Options: {LANGUAGES}"
    ),
    doc_type: str = typer.Option(
        "sentence",
        help=f"Whether to embed at the sentence or document level. Options: {DOCUMENT_TYPES}.",
    ),
    sample: int = typer.Option(1000, help="Maximum number of examples to use."),
    perplexity_model: str = typer.Option(
        "wikipedia",
        help=f"Dataset on which the perplexity model was trained on. Options: {PERPLEXITY_MODELS}",
    ),
    dimensionality_reduction: str = typer.Option(
        DIMENSIONALITY_REDUCTION_ALGORITHMS[0],
        help=f"Whether to use UMAP or t-SNE for dimensionality reduction. Options: {DIMENSIONALITY_REDUCTION_ALGORITHMS}.",
    ),
    model_name: str = typer.Option(
        EMBEDDING_MODELS[0],
        help=f"The sentence embedding model to use. Options: {EMBEDDING_MODELS}",
    ),
    output_file: str = typer.Option(
        "perplexity", help="The name of the output visualization files."
    ),
) -> None:
    """
    Perplexity Lenses: Visualize text embeddings in 2D using colors to represent perplexity values.
    """
    # NOTE: the docstring above is deliberately left short and unchanged:
    # Typer renders a command's docstring as its --help text.
    #
    # Pipeline: load the embedding and KenLM models, build a dataframe with a
    # "perplexity" column (from a local csv/tsv file or a hub dataset), then
    # write `<output_file>.html`, optionally `<output_file>_registry.html`,
    # and `<output_file>_histogram.png`.
    logger.info("Loading embedding model...")
    model = load_model(model_name)

    # Any value other than "umap" (case-insensitive) selects t-SNE.
    reducer = (
        get_umap_embeddings
        if dimensionality_reduction.lower() == "umap"
        else get_tsne_embeddings
    )
    dimensionality_reduction_function = partial(reducer, random_state=SEED)

    logger.info("Loading KenLM model...")
    kenlm_model = KenlmModel.from_pretrained(
        perplexity_model.lower(),
        language,
        lower_case=True,
        remove_accents=True,
        normalize_numbers=True,
        punctuation=1,
    )

    logger.info("Loading dataset...")
    if dataset.endswith((".csv", ".tsv")):
        # Local file: the extension decides the field separator.
        separator = "\t" if dataset.endswith(".tsv") else ","
        df = pd.read_csv(dataset, sep=separator)
        if doc_type.lower() == "sentence":
            # The helper is handed `sample`, so it presumably applies the cap
            # itself after splitting documents into sentences.
            df = documents_df_to_sentences_df(df, text_column, sample, seed=SEED)
        elif len(df) > sample:
            # Document level: honour the `sample` cap here too, otherwise the
            # whole file would be embedded regardless of --sample. Files that
            # already fit are left untouched (same rows, same order).
            df = df.sample(n=sample, random_state=SEED).reset_index(drop=True)
        df["perplexity"] = df[text_column].map(kenlm_model.get_perplexity)
    else:
        # Hub dataset: the loader receives the KenLM model and the sample
        # size; the code below relies on it returning a "perplexity" column.
        df = hub_dataset_to_dataframe(
            dataset,
            dataset_config,
            dataset_split,
            sample,
            text_column,
            kenlm_model,
            seed=SEED,
            doc_type=doc_type,
        )

    # Round perplexity
    df["perplexity"] = df["perplexity"].round().astype(int)
    logger.info(
        "Perplexity range: %s - %s",
        df["perplexity"].min(),
        df["perplexity"].max(),
    )

    plot, plot_registry = generate_plot(
        df,
        text_column,
        "perplexity",
        None,
        dimensionality_reduction_function,
        model,
        seed=SEED,
        hub_dataset=dataset,
    )

    logger.info("Saving plots")
    bokeh_output_file(f"{output_file}.html")
    save(plot)
    # The second plot is only written for the registry dataset.
    if dataset == REGISTRY_DATASET:
        bokeh_output_file(f"{output_file}_registry.html")
        save(plot_registry)
    fig = draw_histogram(df["perplexity"].values)
    fig.savefig(f"{output_file}_histogram.png")
    logger.info("Done")


# Run the Typer app only when this file is executed as a script.
if __name__ == "__main__":
    app()