edugp commited on
Commit
a86046b
1 Parent(s): bf3498e

Support visualizing both sentences and whole documents. Smooth down color assignment in visualization.

Browse files
app.py CHANGED
@@ -6,15 +6,14 @@ import pandas as pd
6
  import streamlit as st
7
  from bokeh.plotting import Figure
8
  from embedding_lenses.data import uploaded_file_to_dataframe
9
- from embedding_lenses.dimensionality_reduction import (get_tsne_embeddings,
10
- get_umap_embeddings)
11
  from embedding_lenses.embedding import embed_text, load_model
12
  from embedding_lenses.utils import encode_labels
13
- from embedding_lenses.visualization import draw_interactive_scatter_plot
14
  from sentence_transformers import SentenceTransformer
15
 
16
- from data import hub_dataset_to_dataframe
17
- from perplexity import KenlmModel
 
18
 
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
@@ -70,6 +69,7 @@ LANGUAGES = [
70
  "uk",
71
  "zh",
72
  ]
 
73
  SEED = 0
74
 
75
 
@@ -113,9 +113,18 @@ with col2:
113
  with col3:
114
  hub_dataset_split = st.text_input("Dataset split", "train")
115
 
116
- text_column = st.text_input("Text field name", "text")
117
- language = st.selectbox("Language", LANGUAGES, 12)
118
- sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
 
 
 
 
 
 
 
 
 
119
  dimensionality_reduction = st.selectbox("Dimensionality Reduction algorithm", DIMENSIONALITY_REDUCTION_ALGORITHMS, 0)
120
  model_name = st.selectbox("Sentence embedding model", EMBEDDING_MODELS, 0)
121
 
@@ -132,10 +141,16 @@ if uploaded_file or hub_dataset:
132
  with st.spinner("Loading dataset..."):
133
  if uploaded_file:
134
  df = uploaded_file_to_dataframe(uploaded_file)
 
 
135
  df["perplexity"] = df[text_column].map(kenlm_model.get_perplexity)
136
  else:
137
- df = hub_dataset_to_dataframe(hub_dataset, hub_dataset_config, hub_dataset_split, sample, text_column, kenlm_model, seed=SEED)
138
- plot = generate_plot(df, text_column, "perplexity", sample, dimensionality_reduction_function, model)
 
 
 
 
139
  logger.info("Displaying plot")
140
  st.bokeh_chart(plot)
141
  logger.info("Done")
 
6
  import streamlit as st
7
  from bokeh.plotting import Figure
8
  from embedding_lenses.data import uploaded_file_to_dataframe
9
+ from embedding_lenses.dimensionality_reduction import get_tsne_embeddings, get_umap_embeddings
 
10
  from embedding_lenses.embedding import embed_text, load_model
11
  from embedding_lenses.utils import encode_labels
 
12
  from sentence_transformers import SentenceTransformer
13
 
14
+ from perplexity_lenses.data import documents_df_to_sentences_df, hub_dataset_to_dataframe
15
+ from perplexity_lenses.perplexity import KenlmModel
16
+ from perplexity_lenses.visualization import draw_interactive_scatter_plot
17
 
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
 
69
  "uk",
70
  "zh",
71
  ]
72
+ DOCUMENT_TYPES = ["Whole document", "Sentence"]
73
  SEED = 0
74
 
75
 
 
113
  with col3:
114
  hub_dataset_split = st.text_input("Dataset split", "train")
115
 
116
+ col4, col5 = st.columns(2)
117
+ with col4:
118
+ text_column = st.text_input("Text field name", "text")
119
+ with col5:
120
+ language = st.selectbox("Language", LANGUAGES, 12)
121
+
122
+ col6, col7 = st.columns(2)
123
+ with col6:
124
+ doc_type = st.selectbox("Document type", DOCUMENT_TYPES, 1)
125
+ with col7:
126
+ sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
127
+
128
  dimensionality_reduction = st.selectbox("Dimensionality Reduction algorithm", DIMENSIONALITY_REDUCTION_ALGORITHMS, 0)
129
  model_name = st.selectbox("Sentence embedding model", EMBEDDING_MODELS, 0)
130
 
 
141
  with st.spinner("Loading dataset..."):
142
  if uploaded_file:
143
  df = uploaded_file_to_dataframe(uploaded_file)
144
+ if doc_type == "Sentence":
145
+ df = documents_df_to_sentences_df(df, text_column, sample, seed=SEED)
146
  df["perplexity"] = df[text_column].map(kenlm_model.get_perplexity)
147
  else:
148
+ df = hub_dataset_to_dataframe(hub_dataset, hub_dataset_config, hub_dataset_split, sample, text_column, kenlm_model, seed=SEED, doc_type=doc_type)
149
+
150
+ # Round perplexity
151
+ df["perplexity"] = df["perplexity"].round().astype(int)
152
+ logger.info(f"Perplexity range: {df['perplexity'].min()} - {df['perplexity'].max()}")
153
+ plot = generate_plot(df, text_column, "perplexity", None, dimensionality_reduction_function, model)
154
  logger.info("Displaying plot")
155
  st.bokeh_chart(plot)
156
  logger.info("Done")
data.py DELETED
@@ -1,28 +0,0 @@
1
- from functools import partial
2
-
3
- import pandas as pd
4
- from datasets import load_dataset
5
- from tqdm import tqdm
6
-
7
- from perplexity import KenlmModel
8
-
9
-
10
- def hub_dataset_to_dataframe(path: str, name: str, split: str, sample: int, text_column: str, model: KenlmModel, seed: int = 0) -> pd.DataFrame:
11
- load_dataset_fn = partial(load_dataset, path=path)
12
- if name:
13
- load_dataset_fn = partial(load_dataset_fn, name=name)
14
- if split:
15
- load_dataset_fn = partial(load_dataset_fn, split=split)
16
- dataset = (
17
- load_dataset_fn(streaming=True)
18
- .shuffle(buffer_size=10000, seed=seed)
19
- .map(lambda x: {text_column: x[text_column], "perplexity": model.get_perplexity(x[text_column])})
20
- )
21
- instances = []
22
- count = 0
23
- for instance in tqdm(dataset, total=sample):
24
- instances.append(instance)
25
- count += 1
26
- if count == sample:
27
- break
28
- return pd.DataFrame(instances)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
perplexity_lenses/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = "0.1.0"
perplexity_lenses/data.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ from datasets import load_dataset
6
+ from tqdm import tqdm
7
+
8
+ from perplexity_lenses.perplexity import KenlmModel
9
+
10
+
11
+ def hub_dataset_to_dataframe(
12
+ path: str, name: str, split: str, sample: int, text_column: str, model: KenlmModel, seed: int = 0, doc_type: str = "Whole document"
13
+ ) -> pd.DataFrame:
14
+ load_dataset_fn = partial(load_dataset, path=path)
15
+ if name:
16
+ load_dataset_fn = partial(load_dataset_fn, name=name)
17
+ if split:
18
+ load_dataset_fn = partial(load_dataset_fn, split=split)
19
+ dataset = load_dataset_fn(streaming=True).shuffle(buffer_size=10000, seed=seed)
20
+ if doc_type == "Sentence":
21
+ dataset = dataset.map(lambda x: [{text_column: sentence, "perplexity": model.get_perplexity(sentence)} for sentence in x[text_column].split("\n")])
22
+ else:
23
+ dataset = dataset.map(lambda x: {text_column: x[text_column], "perplexity": model.get_perplexity(x[text_column])})
24
+ instances = []
25
+ count = 0
26
+ for instance in tqdm(dataset, total=sample):
27
+ if isinstance(instance, list):
28
+ for sentence in instance:
29
+ instances.append(sentence)
30
+ count += 1
31
+ if count == sample:
32
+ break
33
+ else:
34
+ instances.append(instance)
35
+ count += 1
36
+ if count == sample:
37
+ break
38
+ return pd.DataFrame(instances)
39
+
40
+
41
+ def documents_df_to_sentences_df(df: pd.DataFrame, text_column: str, sample: int, seed: int = 0):
42
+ df_sentences = pd.DataFrame({text_column: np.array(df[text_column].map(lambda x: x.split("\n")).values.tolist()).flatten()})
43
+ return df_sentences.sample(min(sample, df.shape[0]), random_state=seed)
perplexity.py → perplexity_lenses/perplexity.py RENAMED
File without changes
perplexity_lenses/visualization.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from bokeh.models import ColumnDataSource, HoverTool
3
+ from bokeh.palettes import Cividis256 as Pallete
4
+ from bokeh.plotting import Figure, figure
5
+ from bokeh.transform import factor_cmap
6
+
7
+
8
+ def draw_interactive_scatter_plot(
9
+ texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
10
+ ) -> Figure:
11
+ # Smooth down values for coloring, by taking the entropy = log10(perplexity) and multiply it by 10000
12
+ values = ((np.log10(values)) * 10000).round().astype(int)
13
+ # Normalize values to range between 0-255, to assign a color for each value
14
+ max_value = values.max()
15
+ min_value = values.min()
16
+ if max_value - min_value == 0:
17
+ values_color = np.ones(len(values))
18
+ else:
19
+ values_color = ((values - min_value) / (max_value - min_value) * 255).round().astype(int)
20
+ values_color_sorted = sorted(values_color)
21
+
22
+ values_list = values.astype(str).tolist()
23
+ values_sorted = sorted(values_list)
24
+ labels_list = labels.astype(str).tolist()
25
+
26
+ source = ColumnDataSource(data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels_list))
27
+ hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])
28
+ p = figure(plot_width=800, plot_height=800, tools=[hover])
29
+ p.circle("x", "y", size=10, source=source, fill_color=factor_cmap("label", palette=[Pallete[id_] for id_ in values_color_sorted], factors=values_sorted))
30
+
31
+ p.axis.visible = False
32
+ p.xgrid.grid_line_color = None
33
+ p.ygrid.grid_line_color = None
34
+ p.toolbar.logo = None
35
+ return p