Spaces:
Runtime error
Runtime error
File size: 1,984 Bytes
d134af5 a9cc2b2 d134af5 a9cc2b2 d134af5 a9cc2b2 d134af5 a9cc2b2 01d5d64 a9cc2b2 d134af5 a9cc2b2 d134af5 a9cc2b2 d134af5 a9cc2b2 d134af5 a9cc2b2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import dash
import plotly.express as px
from dash import dcc, html
from dash.dependencies import Input, Output
from dash.exceptions import PreventUpdate
from datasets import load_dataset
# Dash application instance; the layout and the hover callback below are
# attached to it.
app = dash.Dash(name=__name__)
def get_dataset(name, n_items=1000, seed=None):
    """Load a random sample of the ``ola13/small-{name}-dedup`` train split.

    Args:
        name: Subset name interpolated into the Hub repo id.
        n_items: Maximum number of rows to sample. If the split has fewer
            rows, all of them are used instead of raising.
        seed: Optional shuffle seed for a reproducible sample. ``None``
            (the default) keeps the original unseeded shuffle.

    Returns:
        A ``(dataframe, max_perplexity)`` tuple. The dataframe keeps only the
        ``text``, ``perplexity`` and ``text_length`` columns and is sorted by
        ascending ``perplexity``. ``max_perplexity`` is the largest value in
        that column.
    """
    keep = ("text", "perplexity", "text_length")
    ola_path = f"ola13/small-{name}-dedup"
    hf_dataset = load_dataset(ola_path, split="train").shuffle(seed=seed)
    # Clamp the sample size: selecting indices past the end of the split
    # would raise instead of returning a smaller sample.
    n_rows = min(n_items, len(hf_dataset))
    dataset = hf_dataset.select(range(n_rows)).to_pandas()
    # Vectorised per-document character count (no per-row Python call).
    dataset["text_length"] = dataset["text"].str.len()
    # Drop every other column in one call; remaining column order is kept.
    dataset = dataset.drop(columns=[col for col in dataset.columns if col not in keep])
    dataset = dataset.sort_values("perplexity")
    max_perp = dataset["perplexity"].max()
    return dataset, max_perp
# Subset passed to get_dataset. Other names this file has listed as options:
# "oscar", "the_pile", "c4", "roots_en".
name = "c4"
df, max_perplexity = get_dataset(name)

# One point per document: perplexity against text length, both on log axes.
# The document text rides along as custom data so the hover callback can
# display it.
fig = (
    px.scatter(df, x="perplexity", y="text_length", custom_data=["text"])
    .update_layout(clickmode='event+select')
    .update_traces(marker_size=3)
    .update_xaxes(title_text="Perplexity (log scale)", type="log")
    .update_yaxes(title_text="Text Length (log scale)", type="log")
)
# Inline CSS for the hover-text panel. Keys are camelCase style properties.
# Values must not carry a trailing ";": "pre-wrap;" is not a valid CSS value,
# so the white-space rule would simply not be applied.
styles = {
    'textbox': {
        'border': 'thin lightgrey solid',
        'overflowX': 'scroll',
        # Preserve the document's newlines while still wrapping long lines.
        "whiteSpace": "pre-wrap",
    },
}
# Page layout: the interactive scatter plot, followed by the panel that the
# hover callback fills with the hovered document's text.
app.layout = html.Div(
    children=[
        dcc.Graph(id="graph_interaction", figure=fig),
        html.Div(id='text', style=styles['textbox']),
    ]
)
# Hover callback: copies the text of the hovered point into the 'text' panel.
# NOTE: the function name is historical; it does not open any URL.
@app.callback(
    Output('text', 'children'),
    Input('graph_interaction', 'hoverData'))
def open_url(hoverData):
    """Return the custom-data text of the first hovered point.

    Raises ``PreventUpdate`` when there is no hover data, so the output is
    left unchanged.
    """
    if not hoverData:
        raise PreventUpdate
    hovered_point = hoverData["points"][0]
    return hovered_point["customdata"][0]
if __name__ == '__main__':
    # Dash 3 removed ``app.run_server`` in favour of ``app.run``. Fall back
    # to ``run_server`` only on older Dash releases that lack ``run``.
    run_app = getattr(app, "run", None) or app.run_server
    # NOTE(review): debug=True enables Dash dev tools (and presumably the
    # Flask/Werkzeug debugger) while listening on all interfaces (0.0.0.0).
    # Consider disabling it for public deployments.
    run_app(port=7860, host="0.0.0.0", debug=True)
|