update push dataset
Browse files
src/distilabel_dataset_generator/apps/sft.py
CHANGED
@@ -71,14 +71,14 @@ def generate_system_prompt(dataset_description, temperature, progress=gr.Progres
|
|
71 |
|
72 |
|
73 |
def generate_sample_dataset(system_prompt, num_turns, progress=gr.Progress()):
|
74 |
-
|
75 |
system_prompt=system_prompt,
|
76 |
num_turns=num_turns,
|
77 |
num_rows=10,
|
78 |
progress=progress,
|
79 |
is_sample=True,
|
80 |
)
|
81 |
-
return
|
82 |
|
83 |
|
84 |
def generate_dataset(
|
@@ -202,7 +202,7 @@ def push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private):
|
|
202 |
return original_dataframe
|
203 |
|
204 |
|
205 |
-
def
|
206 |
org_name: str,
|
207 |
repo_name: str,
|
208 |
system_prompt: str,
|
@@ -400,7 +400,10 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
|
|
400 |
)
|
401 |
with gr.Column(scale=3):
|
402 |
dataframe = gr.Dataframe(
|
403 |
-
headers=["prompt", "completion"],
|
|
|
|
|
|
|
404 |
)
|
405 |
|
406 |
gr.HTML(value="<hr>")
|
@@ -445,6 +448,7 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
|
|
445 |
label="Distilabel Pipeline Code",
|
446 |
)
|
447 |
|
|
|
448 |
load_btn.click(
|
449 |
fn=generate_system_prompt,
|
450 |
inputs=[dataset_description, temperature],
|
@@ -456,7 +460,8 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
|
|
456 |
outputs=[dataframe],
|
457 |
show_progress=True,
|
458 |
)
|
459 |
-
|
|
|
460 |
btn_apply_to_sample_dataset.click(
|
461 |
fn=generate_sample_dataset,
|
462 |
inputs=[system_prompt, num_turns],
|
@@ -479,7 +484,7 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
|
|
479 |
outputs=[success_message],
|
480 |
show_progress=True,
|
481 |
).success(
|
482 |
-
fn=
|
483 |
inputs=[
|
484 |
org_name,
|
485 |
repo_name,
|
@@ -500,5 +505,6 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
|
|
500 |
outputs=[pipeline_code_ui],
|
501 |
)
|
502 |
|
|
|
503 |
app.load(fn=swap_visibility, outputs=main_ui)
|
504 |
app.load(fn=get_org_dropdown, outputs=[org_name])
|
|
|
71 |
|
72 |
|
73 |
def generate_sample_dataset(system_prompt, num_turns, progress=gr.Progress()):
|
74 |
+
dataframe = generate_dataset(
|
75 |
system_prompt=system_prompt,
|
76 |
num_turns=num_turns,
|
77 |
num_rows=10,
|
78 |
progress=progress,
|
79 |
is_sample=True,
|
80 |
)
|
81 |
+
return dataframe
|
82 |
|
83 |
|
84 |
def generate_dataset(
|
|
|
202 |
return original_dataframe
|
203 |
|
204 |
|
205 |
+
def push_dataset(
|
206 |
org_name: str,
|
207 |
repo_name: str,
|
208 |
system_prompt: str,
|
|
|
400 |
)
|
401 |
with gr.Column(scale=3):
|
402 |
dataframe = gr.Dataframe(
|
403 |
+
headers=["prompt", "completion"],
|
404 |
+
wrap=True,
|
405 |
+
height=500,
|
406 |
+
interactive=False,
|
407 |
)
|
408 |
|
409 |
gr.HTML(value="<hr>")
|
|
|
448 |
label="Distilabel Pipeline Code",
|
449 |
)
|
450 |
|
451 |
+
|
452 |
load_btn.click(
|
453 |
fn=generate_system_prompt,
|
454 |
inputs=[dataset_description, temperature],
|
|
|
460 |
outputs=[dataframe],
|
461 |
show_progress=True,
|
462 |
)
|
463 |
+
|
464 |
+
|
465 |
btn_apply_to_sample_dataset.click(
|
466 |
fn=generate_sample_dataset,
|
467 |
inputs=[system_prompt, num_turns],
|
|
|
484 |
outputs=[success_message],
|
485 |
show_progress=True,
|
486 |
).success(
|
487 |
+
fn=push_dataset,
|
488 |
inputs=[
|
489 |
org_name,
|
490 |
repo_name,
|
|
|
505 |
outputs=[pipeline_code_ui],
|
506 |
)
|
507 |
|
508 |
+
|
509 |
app.load(fn=swap_visibility, outputs=main_ui)
|
510 |
app.load(fn=get_org_dropdown, outputs=[org_name])
|
src/distilabel_dataset_generator/apps/textcat.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import json
|
2 |
-
import re
|
3 |
import uuid
|
4 |
from typing import List, Union
|
5 |
|
@@ -189,7 +188,7 @@ def push_dataset_to_hub(
|
|
189 |
)
|
190 |
|
191 |
|
192 |
-
def
|
193 |
org_name: str,
|
194 |
repo_name: str,
|
195 |
system_prompt: str,
|
@@ -509,7 +508,7 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
|
|
509 |
outputs=[success_message],
|
510 |
show_progress=True,
|
511 |
).success(
|
512 |
-
fn=
|
513 |
inputs=[
|
514 |
org_name,
|
515 |
repo_name,
|
|
|
1 |
import json
|
|
|
2 |
import uuid
|
3 |
from typing import List, Union
|
4 |
|
|
|
188 |
)
|
189 |
|
190 |
|
191 |
+
def push_dataset(
|
192 |
org_name: str,
|
193 |
repo_name: str,
|
194 |
system_prompt: str,
|
|
|
508 |
outputs=[success_message],
|
509 |
show_progress=True,
|
510 |
).success(
|
511 |
+
fn=push_dataset,
|
512 |
inputs=[
|
513 |
org_name,
|
514 |
repo_name,
|