Commit 4106f96 by davidberenstein1957
Parent: b6646ba

fix openai compatibility

README.md CHANGED
@@ -82,13 +82,15 @@ Optionally, you can set the following environment variables to customize the gen
 - `MAX_NUM_ROWS`: The maximum number of rows to generate, defaults to `1000`.
 - `DEFAULT_BATCH_SIZE`: The default batch size to use for generating the dataset, defaults to `5`.
 
-Optionally, you can use different models and APIs.
+Optionally, you can use different models and APIs. For providers outside of Hugging Face, we provide an integration through [LiteLLM](https://docs.litellm.ai/docs/providers).
 
-- `BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api-inference.huggingface.co/v1/`, `https://api.openai.com/v1/`.
-- `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`.
+- `BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api-inference.huggingface.co/v1/`, `https://api.openai.com/v1/`, `http://127.0.0.1:11434/v1/`.
+- `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `openai/gpt-4o`, `ollama/llama3.1`.
 - `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the provided `HF_TOKEN` environment variable.
-- `MAGPIE_PRE_QUERY_TEMPLATE`: Enforce setting the pre-query template for Magpie. Llama3 and Qwen2 are supported out of the box and will use `"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"` and `"<|im_start|>user\n"` respectively. For other models, you can pass a custom pre-query template string.
+
+SFT and Chat Data generation is only supported with Hugging Face Inference Endpoints, and you can set the following environment variables to use it with models other than Llama3 and Qwen2.
+
+- `MAGPIE_PRE_QUERY_TEMPLATE`: Enforce setting the pre-query template for Magpie, which is only supported with Hugging Face Inference Endpoints. Llama3 and Qwen2 are supported out of the box and will use `"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"` and `"<|im_start|>user\n"` respectively. For other models, you can pass a custom pre-query template string.
 
 Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:
 
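As a concrete illustration of the variables documented above, the sketch below points the generator at the OpenAI API. The values are placeholders, not project defaults; any OpenAI compatible provider works the same way:

```python
import os

# Illustrative values only; the `openai/` model prefix follows the
# LiteLLM provider naming referenced in the README above.
os.environ["BASE_URL"] = "https://api.openai.com/v1/"
os.environ["MODEL"] = "openai/gpt-4o"
os.environ["API_KEY"] = "sk-..."  # placeholder; supply a real key

from synthetic_dataset_generator import launch

launch()
```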
app.py CHANGED
@@ -1,3 +1,8 @@
+import os
+
 from synthetic_dataset_generator import launch
 
+os.environ["BASE_URL"] = "http://localhost:11434/v1"
+os.environ["MODEL"] = "llama3.1"
+
 launch()
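One caveat with this hunk: as the constants.py diff below shows, `BASE_URL` and `MODEL` are read through `os.getenv` at module import time, so overrides assigned after `from synthetic_dataset_generator import launch` may not be picked up. A reordered sketch that avoids the issue:

```python
import os

# Assign overrides first, so that synthetic_dataset_generator.constants
# sees them when the package is imported below.
os.environ["BASE_URL"] = "http://localhost:11434/v1"
os.environ["MODEL"] = "llama3.1"

from synthetic_dataset_generator import launch

launch()
```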
src/synthetic_dataset_generator/apps/chat.py CHANGED
@@ -20,6 +20,7 @@ from synthetic_dataset_generator.apps.base import (
     validate_push_to_hub,
 )
 from synthetic_dataset_generator.constants import (
+    BASE_URL,
     DEFAULT_BATCH_SIZE,
     MODEL,
     SFT_AVAILABLE,
@@ -413,8 +414,8 @@ with gr.Blocks() as app:
             [
                 "## Supervised Fine-Tuning not available",
                 "",
-                f"This tool relies on the [Magpie](https://arxiv.org/abs/2406.08464) prequery template, which is not implemented for the {MODEL} model.",
-                "Use Llama3 or Qwen2 models or [implement another magpie prequery template](https://github.com/argilla-io/distilabel/pull/778/files).",
+                f"This tool relies on the [Magpie](https://arxiv.org/abs/2406.08464) prequery template, which is not implemented for the {MODEL} model with {BASE_URL}.",
+                "Use Llama3 or Qwen2 models with Hugging Face Inference Endpoints.",
             ]
         )
     )
src/synthetic_dataset_generator/constants.py CHANGED
@@ -19,6 +19,8 @@ MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
 MAX_NUM_ROWS: str | int = int(os.getenv("MAX_NUM_ROWS", 1000))
 DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
 MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
+BASE_URL = os.getenv("BASE_URL", default=None)
+
 _API_KEY = os.getenv("API_KEY")
 if _API_KEY:
     API_KEYS = [_API_KEY]
@@ -27,12 +29,9 @@ else:
         os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)
     ]
 API_KEYS = [token for token in API_KEYS if token]
-BASE_URL = os.getenv("BASE_URL", "https://api-inference.huggingface.co/v1/")
 
-if BASE_URL != "https://api-inference.huggingface.co/v1/" and len(API_KEYS) == 0:
-    raise ValueError(
-        "API_KEY is not set. Ensure you have set the API_KEY environment variable that has access to the Hugging Face Inference Endpoints."
-    )
+# Determine if SFT is available
+SFT_AVAILABLE = False
 llama_options = ["llama3", "llama-3", "llama 3"]
 qwen_options = ["qwen2", "qwen-2", "qwen 2"]
 if os.getenv("MAGPIE_PRE_QUERY_TEMPLATE"):
@@ -54,14 +53,16 @@ elif MODEL.lower() in qwen_options or any(
 ):
     SFT_AVAILABLE = True
     MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
-else:
+
+if BASE_URL:
     SFT_AVAILABLE = False
+
+if not SFT_AVAILABLE:
     warnings.warn(
-        "`SFT_AVAILABLE` is set to `False` because the model is not a Qwen or Llama model."
+        message="`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints to generate chat data."
     )
     MAGPIE_PRE_QUERY_TEMPLATE = None
 
-
 # Embeddings
 STATIC_EMBEDDING_MODEL = "minishlab/potion-base-8M"
 
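Read together, the new constants amount to a simple availability rule: Magpie-based SFT stays enabled only for Llama 3 / Qwen 2 style models (or an explicit `MAGPIE_PRE_QUERY_TEMPLATE`) served through Hugging Face Inference Endpoints, and any custom `BASE_URL` disables it. A condensed restatement as a hypothetical helper, not part of this commit:

```python
def is_sft_available(model: str, base_url: str | None, magpie_template: str | None) -> bool:
    # Hypothetical helper restating the decision flow in constants.py.
    if base_url:  # an OpenAI-compatible BASE_URL always disables Magpie SFT
        return False
    if magpie_template:  # explicit MAGPIE_PRE_QUERY_TEMPLATE forces support
        return True
    # Otherwise only Llama 3 / Qwen 2 style model names qualify.
    name = model.lower()
    options = ["llama3", "llama-3", "llama 3", "qwen2", "qwen-2", "qwen 2"]
    return any(option in name for option in options)
```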
src/synthetic_dataset_generator/pipelines/textcat.py CHANGED
@@ -1,7 +1,7 @@
 import random
 from typing import List
 
-from distilabel.llms import InferenceEndpointsLLM
+from distilabel.llms import InferenceEndpointsLLM, OpenAILLM
 from distilabel.steps.tasks import (
     GenerateTextClassificationData,
     TextClassification,
@@ -61,39 +61,66 @@ class TextClassificationTask(BaseModel):
 
 
 def get_prompt_generator():
-    prompt_generator = TextGeneration(
-        llm=InferenceEndpointsLLM(
+    structured_output = {
+        "format": "json",
+        "schema": TextClassificationTask,
+    }
+    generation_kwargs = {
+        "temperature": 0.8,
+        "max_new_tokens": MAX_NUM_TOKENS,
+    }
+    if BASE_URL:
+        llm = OpenAILLM(
+            model=MODEL,
+            base_url=BASE_URL,
+            api_key=_get_next_api_key(),
+            structured_output=structured_output,
+            generation_kwargs=generation_kwargs,
+        )
+    else:
+        generation_kwargs["do_sample"] = True
+        llm = InferenceEndpointsLLM(
             api_key=_get_next_api_key(),
             model_id=MODEL,
             base_url=BASE_URL,
-            structured_output={"format": "json", "schema": TextClassificationTask},
-            generation_kwargs={
-                "temperature": 0.8,
-                "max_new_tokens": MAX_NUM_TOKENS,
-                "do_sample": True,
-            },
-        ),
+            structured_output=structured_output,
+            generation_kwargs=generation_kwargs,
+        )
+
+    prompt_generator = TextGeneration(
+        llm=llm,
         system_prompt=PROMPT_CREATION_PROMPT,
         use_system_prompt=True,
     )
+
     prompt_generator.load()
     return prompt_generator
 
 
 def get_textcat_generator(difficulty, clarity, temperature, is_sample):
-    textcat_generator = GenerateTextClassificationData(
-        llm=InferenceEndpointsLLM(
+    generation_kwargs = {
+        "temperature": temperature,
+        "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
+        "top_p": 0.95,
+    }
+    if BASE_URL:
+        llm = OpenAILLM(
+            model=MODEL,
+            base_url=BASE_URL,
+            api_key=_get_next_api_key(),
+            generation_kwargs=generation_kwargs,
+        )
+    else:
+        generation_kwargs["do_sample"] = True
+        llm = InferenceEndpointsLLM(
             model_id=MODEL,
             base_url=BASE_URL,
             api_key=_get_next_api_key(),
-            generation_kwargs={
-                "temperature": temperature,
-                "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
-                "do_sample": True,
-                "top_k": 50,
-                "top_p": 0.95,
-            },
-        ),
+            generation_kwargs=generation_kwargs,
+        )
+
+    textcat_generator = GenerateTextClassificationData(
+        llm=llm,
         difficulty=None if difficulty == "mixed" else difficulty,
         clarity=None if clarity == "mixed" else clarity,
         seed=random.randint(0, 2**32 - 1),
@@ -103,16 +130,28 @@ def get_textcat_generator(difficulty, clarity, temperature, is_sample):
 
 
 def get_labeller_generator(system_prompt, labels, multi_label):
-    labeller_generator = TextClassification(
-        llm=InferenceEndpointsLLM(
+    generation_kwargs = {
+        "temperature": 0.01,
+        "max_new_tokens": MAX_NUM_TOKENS,
+    }
+
+    if BASE_URL:
+        llm = OpenAILLM(
+            model=MODEL,
+            base_url=BASE_URL,
+            api_key=_get_next_api_key(),
+            generation_kwargs=generation_kwargs,
+        )
+    else:
+        llm = InferenceEndpointsLLM(
             model_id=MODEL,
             base_url=BASE_URL,
            api_key=_get_next_api_key(),
-            generation_kwargs={
-                "temperature": 0.7,
-                "max_new_tokens": MAX_NUM_TOKENS,
-            },
-        ),
+            generation_kwargs=generation_kwargs,
+        )
+
+    labeller_generator = TextClassification(
+        llm=llm,
         context=system_prompt,
         available_labels=labels,
         n=len(labels) if multi_label else 1,
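As a closing observation, the `if BASE_URL:` branch that selects between `OpenAILLM` and `InferenceEndpointsLLM` now repeats in all three factory functions of this file. One possible follow-up, sketched with a hypothetical `_get_llm` helper that is not part of this commit and assumes the module's existing `BASE_URL`, `MODEL`, and `_get_next_api_key` names, would centralize the choice (the `do_sample` flag is left to the callers, since the labeller path does not set it):

```python
def _get_llm(generation_kwargs: dict, **kwargs):
    # Hypothetical helper: route to an OpenAI compatible client when
    # BASE_URL is set, otherwise use Hugging Face Inference Endpoints.
    if BASE_URL:
        return OpenAILLM(
            model=MODEL,
            base_url=BASE_URL,
            api_key=_get_next_api_key(),
            generation_kwargs=generation_kwargs,
            **kwargs,
        )
    return InferenceEndpointsLLM(
        model_id=MODEL,
        base_url=BASE_URL,
        api_key=_get_next_api_key(),
        generation_kwargs=generation_kwargs,
        **kwargs,
    )
```

Each factory would then reduce to a single call such as `llm = _get_llm({"temperature": 0.8, "max_new_tokens": MAX_NUM_TOKENS}, structured_output=structured_output)`.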