texcat-generation_review (#13)
Browse files- bug: correct generation parameters (2a1af1a85c406f402dbfec69ed1bdd13e0aa5066)
src/distilabel_dataset_generator/apps/base.py
CHANGED
@@ -427,6 +427,7 @@ def push_dataset_to_hub(
|
|
427 |
|
428 |
if task == TEXTCAT_TASK:
|
429 |
if num_labels == 1:
|
|
|
430 |
features = Features(
|
431 |
{"text": Value("string"), "label": ClassLabel(names=labels)}
|
432 |
)
|
|
|
427 |
|
428 |
if task == TEXTCAT_TASK:
|
429 |
if num_labels == 1:
|
430 |
+
dataframe["label"] = dataframe["label"].replace("", None)
|
431 |
features = Features(
|
432 |
{"text": Value("string"), "label": ClassLabel(names=labels)}
|
433 |
)
|
src/distilabel_dataset_generator/apps/textcat.py
CHANGED
@@ -53,6 +53,9 @@ def push_dataset_to_hub(
|
|
53 |
num_labels: int = 1,
|
54 |
):
|
55 |
original_dataframe = dataframe.copy(deep=True)
|
|
|
|
|
|
|
56 |
labels = get_preprocess_labels(labels)
|
57 |
try:
|
58 |
push_to_hub_base(
|
@@ -80,6 +83,9 @@ def push_dataset_to_argilla(
|
|
80 |
labels: List[str] = None,
|
81 |
) -> pd.DataFrame:
|
82 |
original_dataframe = dataframe.copy(deep=True)
|
|
|
|
|
|
|
83 |
try:
|
84 |
progress(0.1, desc="Setting up user and workspace")
|
85 |
client = get_argilla_client()
|
|
|
53 |
num_labels: int = 1,
|
54 |
):
|
55 |
original_dataframe = dataframe.copy(deep=True)
|
56 |
+
dataframe = dataframe[
|
57 |
+
(dataframe["text"].str.strip() != "") & (dataframe["text"].notna())
|
58 |
+
]
|
59 |
labels = get_preprocess_labels(labels)
|
60 |
try:
|
61 |
push_to_hub_base(
|
|
|
83 |
labels: List[str] = None,
|
84 |
) -> pd.DataFrame:
|
85 |
original_dataframe = dataframe.copy(deep=True)
|
86 |
+
dataframe = dataframe[
|
87 |
+
(dataframe["text"].str.strip() != "") & (dataframe["text"].notna())
|
88 |
+
]
|
89 |
try:
|
90 |
progress(0.1, desc="Setting up user and workspace")
|
91 |
client = get_argilla_client()
|
src/distilabel_dataset_generator/pipelines/textcat.py
CHANGED
@@ -114,9 +114,11 @@ with Pipeline(name="textcat") as pipeline:
|
|
114 |
"temperature": 0.8,
|
115 |
"max_new_tokens": 2048,
|
116 |
"do_sample": True,
|
117 |
-
"
|
|
|
118 |
}},
|
119 |
),
|
|
|
120 |
difficulty={None if difficulty == "mixed" else repr(difficulty)},
|
121 |
clarity={None if clarity == "mixed" else repr(clarity)},
|
122 |
num_generations={num_rows},
|
@@ -182,11 +184,13 @@ def get_textcat_generator(difficulty, clarity, is_sample):
|
|
182 |
"temperature": 0.9,
|
183 |
"max_new_tokens": 256 if is_sample else 2048,
|
184 |
"do_sample": True,
|
185 |
-
"
|
|
|
186 |
},
|
187 |
),
|
188 |
difficulty=None if difficulty == "mixed" else difficulty,
|
189 |
clarity=None if clarity == "mixed" else clarity,
|
|
|
190 |
)
|
191 |
textcat_generator.load()
|
192 |
return textcat_generator
|
|
|
114 |
"temperature": 0.8,
|
115 |
"max_new_tokens": 2048,
|
116 |
"do_sample": True,
|
117 |
+
"top_k": 50,
|
118 |
+
"top_p": 0.95,
|
119 |
}},
|
120 |
),
|
121 |
+
seed=random.randint(0, 2**32 - 1),
|
122 |
difficulty={None if difficulty == "mixed" else repr(difficulty)},
|
123 |
clarity={None if clarity == "mixed" else repr(clarity)},
|
124 |
num_generations={num_rows},
|
|
|
184 |
"temperature": 0.9,
|
185 |
"max_new_tokens": 256 if is_sample else 2048,
|
186 |
"do_sample": True,
|
187 |
+
"top_k": 50,
|
188 |
+
"top_p": 0.95,
|
189 |
},
|
190 |
),
|
191 |
difficulty=None if difficulty == "mixed" else difficulty,
|
192 |
clarity=None if clarity == "mixed" else clarity,
|
193 |
+
seed=random.randint(0, 2**32 - 1),
|
194 |
)
|
195 |
textcat_generator.load()
|
196 |
return textcat_generator
|
src/distilabel_dataset_generator/utils.py
CHANGED
@@ -124,4 +124,4 @@ def get_argilla_client() -> Union[rg.Argilla, None]:
|
|
124 |
return None
|
125 |
|
126 |
def get_preprocess_labels(labels: Optional[List[str]]) -> List[str]:
|
127 |
-
return [label.lower().strip() for label in labels] if labels else []
|
|
|
124 |
return None
|
125 |
|
126 |
def get_preprocess_labels(labels: Optional[List[str]]) -> List[str]:
    """Normalize label names and remove duplicates, keeping first-seen order.

    Each label is lower-cased and stripped of surrounding whitespace, so
    variants such as ``" Positive"`` and ``"positive"`` collapse into one
    entry.

    Args:
        labels: Raw label names, or ``None``/empty when no labels were given.

    Returns:
        The normalized, de-duplicated labels in the order they first appear
        in ``labels``; an empty list when ``labels`` is ``None`` or empty.
    """
    if not labels:
        return []
    normalized = (label.lower().strip() for label in labels)
    # dict.fromkeys de-duplicates while preserving insertion order. A plain
    # set() would return the labels in an arbitrary order that can change
    # between interpreter runs (string hash randomization), which makes any
    # position-based use of the result (e.g. class ids) non-reproducible.
    return list(dict.fromkeys(normalized))
|