polinaeterna HF staff commited on
Commit
b1d4b4a
β€’
1 Parent(s): 0a44dc6
Files changed (1) hide show
  1. app.py +20 -12
app.py CHANGED
@@ -103,7 +103,7 @@ def run_quality_check(dataset, column, batch_size, num_examples):
103
  batch_predictions = predict(batch_texts)
104
  predictions.extend(batch_predictions)
105
  texts_processed.extend(batch_texts)
106
- yield {"check in progress...": min(i+batch_size, num_examples) / num_examples}, *plot_and_df(texts_processed, predictions), pd.DataFrame()
107
 
108
  # with multiprocessing.Pool(processes=8) as pool:
109
  # props = pool.map(proportion_non_ascii, texts)
@@ -130,22 +130,21 @@ def plot_toxicity(scores):
130
  fig, axs = plt.subplots(2, 3)#, figsize=(10, 6))
131
  for x, y, score_name in zip([0,0,0,1,1,1], [0,1,2,0,1,2], scores):
132
  axs[x,y].hist(scores[score_name], bins=20, range=(0., 1.))
133
- # axs[x,y].set_title(f'Histogram of {score_name}')
134
- axs[x,y].set_xlabel(f'{score_name}')
135
- # axs[x,y].set_ylabel('Number of texts')
136
  fig.supylabel("Number of texts")
137
  fig.suptitle("Histogram of toxicity scores")
138
  fig.tight_layout()
139
 
140
  return fig
141
 
142
- def call_perspective_api(texts_df, column_name):#, s):
143
  headers = {
144
  "content-type": "application/json",
145
  }
146
  req_att_scores = {attr: [] for attr in REQUESTED_ATTRIBUTES}
147
 
148
- texts = texts_df[column_name].values
 
149
  n_samples = len(texts)
150
  for i, text in tqdm(enumerate(texts), desc="scanning with perspective"):
151
  data = {
@@ -184,7 +183,8 @@ def call_perspective_api(texts_df, column_name):#, s):
184
  return req_att_scores
185
  if i % 10 == 0:
186
  plot_toxicity(req_att_scores)
187
- yield {"toxicity check in progress...": i / n_samples}, plt.gcf(), pd.DataFrame.from_dict({column_name: texts[:i], **req_att_scores})
 
188
 
189
  plot_toxicity(req_att_scores)
190
  yield {"toxicity check finished.": 1.}, plt.gcf(), pd.DataFrame.from_dict({column_name: texts, **req_att_scores})
@@ -224,6 +224,7 @@ with gr.Blocks() as demo:
224
  """
225
  # πŸ’« Dataset Quality Checker πŸ’«
226
  Use [nvidia/quality-classifier-deberta](https://huggingface.co/nvidia/quality-classifier-deberta) on any text dataset on the Hub.
 
227
  """
228
  )
229
  dataset_name = HuggingfaceHubSearch(
@@ -247,6 +248,8 @@ with gr.Blocks() as demo:
247
  return gr.HTML(value=html_code)
248
 
249
  text_column = gr.Textbox(placeholder="text", label="Text colum name to check (data must be non-nested, raw texts!)")
 
 
250
  batch_size = gr.Slider(0, 128, 32, step=8, label="Inference batch size (set this to smaller value if this space crashes.)")
251
  num_examples = gr.Number(500, label="Number of first examples to check")
252
  gr_check_btn = gr.Button("Check Dataset")
@@ -262,18 +265,23 @@ with gr.Blocks() as demo:
262
  gr.Markdown("### High")
263
  df_high = gr.DataFrame()
264
 
265
- texts_sample_df = gr.DataFrame(visible=False)
266
  gr_check_btn.click(
267
  run_quality_check,
268
  inputs=[dataset_name, text_column, batch_size, num_examples],
269
- outputs=[progress_bar, plot, df_low, df_medium, df_high, texts_sample_df]
270
  )
271
 
272
- gr_ascii_btn = gr.Button("Non ascii chars.")
 
 
 
273
  non_ascii_hist = gr.Plot()
274
 
275
- gr_ascii_btn.click(non_ascii_check, inputs=[texts_sample_df, text_column], outputs=[non_ascii_hist])
276
 
 
 
277
  gr_toxicity_btn = gr.Button("Run perpspective API to check toxicity of random samples.")
278
  toxicity_progress_bar = gr.Label(show_label=False)
279
  toxicity_hist = gr.Plot()
@@ -281,7 +289,7 @@ with gr.Blocks() as demo:
281
  toxicity_df = gr.DataFrame()
282
  gr_toxicity_btn.click(
283
  call_perspective_api,
284
- inputs=[texts_sample_df, text_column],
285
  outputs=[toxicity_progress_bar, toxicity_hist, toxicity_df]
286
  )
287
 
 
103
  batch_predictions = predict(batch_texts)
104
  predictions.extend(batch_predictions)
105
  texts_processed.extend(batch_texts)
106
+ yield {"check in progress...": i / num_examples}, *plot_and_df(texts_processed, predictions), pd.DataFrame()
107
 
108
  # with multiprocessing.Pool(processes=8) as pool:
109
  # props = pool.map(proportion_non_ascii, texts)
 
130
  fig, axs = plt.subplots(2, 3)#, figsize=(10, 6))
131
  for x, y, score_name in zip([0,0,0,1,1,1], [0,1,2,0,1,2], scores):
132
  axs[x,y].hist(scores[score_name], bins=20, range=(0., 1.))
133
+ axs[x,y].set_xlabel(score_name)
 
 
134
  fig.supylabel("Number of texts")
135
  fig.suptitle("Histogram of toxicity scores")
136
  fig.tight_layout()
137
 
138
  return fig
139
 
140
+ def call_perspective_api(texts_df, column_name, full_check=False):
141
  headers = {
142
  "content-type": "application/json",
143
  }
144
  req_att_scores = {attr: [] for attr in REQUESTED_ATTRIBUTES}
145
 
146
+ texts = texts_df.sample(100, random_state=16)[column_name].values if not full_check else texts_df[column_name].values
147
+
148
  n_samples = len(texts)
149
  for i, text in tqdm(enumerate(texts), desc="scanning with perspective"):
150
  data = {
 
183
  return req_att_scores
184
  if i % 10 == 0:
185
  plot_toxicity(req_att_scores)
186
+ print(len(texts[:i]), len(req_att_scores["TOXICITY"]))
187
+ yield {"toxicity check in progress...": i / n_samples}, plt.gcf(), pd.DataFrame.from_dict({column_name: texts[:i+1], **req_att_scores})
188
 
189
  plot_toxicity(req_att_scores)
190
  yield {"toxicity check finished.": 1.}, plt.gcf(), pd.DataFrame.from_dict({column_name: texts, **req_att_scores})
 
224
  """
225
  # πŸ’« Dataset Quality Checker πŸ’«
226
  Use [nvidia/quality-classifier-deberta](https://huggingface.co/nvidia/quality-classifier-deberta) on any text dataset on the Hub.
227
+ ## Select dataset and text column
228
  """
229
  )
230
  dataset_name = HuggingfaceHubSearch(
 
248
  return gr.HTML(value=html_code)
249
 
250
  text_column = gr.Textbox(placeholder="text", label="Text colum name to check (data must be non-nested, raw texts!)")
251
+
252
+ gr.Markdown("## Run nvidia quality classifier")
253
  batch_size = gr.Slider(0, 128, 32, step=8, label="Inference batch size (set this to smaller value if this space crashes.)")
254
  num_examples = gr.Number(500, label="Number of first examples to check")
255
  gr_check_btn = gr.Button("Check Dataset")
 
265
  gr.Markdown("### High")
266
  df_high = gr.DataFrame()
267
 
268
+ texts_df = gr.DataFrame(visible=False)
269
  gr_check_btn.click(
270
  run_quality_check,
271
  inputs=[dataset_name, text_column, batch_size, num_examples],
272
+ outputs=[progress_bar, plot, df_low, df_medium, df_high, texts_df]
273
  )
274
 
275
+ gr.Markdown("""## Compute text quality measures
276
+ * proportion of non-ascii characters
277
+ * #TODO""")
278
+ gr_ascii_btn = gr.Button("Data measures")
279
  non_ascii_hist = gr.Plot()
280
 
281
+ gr_ascii_btn.click(non_ascii_check, inputs=[texts_df, text_column], outputs=[non_ascii_hist])
282
 
283
+ gr.Markdown("## Explore toxicity")
284
+ checkbox = gr.Checkbox(value=False, label="Run on full first parquet data (better not)")
285
  gr_toxicity_btn = gr.Button("Run perpspective API to check toxicity of random samples.")
286
  toxicity_progress_bar = gr.Label(show_label=False)
287
  toxicity_hist = gr.Plot()
 
289
  toxicity_df = gr.DataFrame()
290
  gr_toxicity_btn.click(
291
  call_perspective_api,
292
+ inputs=[texts_df, text_column, checkbox],
293
  outputs=[toxicity_progress_bar, toxicity_hist, toxicity_df]
294
  )
295