Commit b129294 · Parent(s): 784963b · committed by sdiazlor

Fix BASE_URL_COMPLETIONS when using different providers
src/synthetic_dataset_generator/pipelines/base.py CHANGED
@@ -87,10 +87,17 @@ def _get_llm(
 ):
     model = MODEL_COMPLETION if is_completion else MODEL
     tokenizer_id = TOKENIZER_ID_COMPLETION if is_completion else TOKENIZER_ID or model
-    if OPENAI_BASE_URL:
+    base_urls = {
+        "openai": OPENAI_BASE_URL_COMPLETION if is_completion else OPENAI_BASE_URL,
+        "ollama": OLLAMA_BASE_URL_COMPLETION if is_completion else OLLAMA_BASE_URL,
+        "huggingface": HUGGINGFACE_BASE_URL_COMPLETION if is_completion else HUGGINGFACE_BASE_URL,
+        "vllm": VLLM_BASE_URL_COMPLETION if is_completion else VLLM_BASE_URL,
+    }
+
+    if base_urls["openai"]:
         llm = OpenAILLM(
             model=model,
-            base_url=OPENAI_BASE_URL_COMPLETION if is_completion else OPENAI_BASE_URL,
+            base_url=base_urls["openai"],
             api_key=_get_next_api_key(),
             structured_output=structured_output,
             **kwargs,
@@ -103,7 +110,7 @@ def _get_llm(
             del kwargs["generation_kwargs"]["stop_sequences"]
         if "do_sample" in kwargs["generation_kwargs"]:
             del kwargs["generation_kwargs"]["do_sample"]
-    elif OLLAMA_BASE_URL:
+    elif base_urls["ollama"]:
         if "generation_kwargs" in kwargs:
             if "max_new_tokens" in kwargs["generation_kwargs"]:
                 kwargs["generation_kwargs"]["num_predict"] = kwargs[
@@ -123,32 +130,28 @@ def _get_llm(
             kwargs["generation_kwargs"]["options"] = options
         llm = OllamaLLM(
             model=model,
-            host=OLLAMA_BASE_URL_COMPLETION if is_completion else OLLAMA_BASE_URL,
+            host=base_urls["ollama"],
             tokenizer_id=tokenizer_id,
             use_magpie_template=use_magpie_template,
             structured_output=structured_output,
             **kwargs,
         )
-    elif HUGGINGFACE_BASE_URL:
+    elif base_urls["huggingface"]:
         kwargs["generation_kwargs"]["do_sample"] = True
         llm = InferenceEndpointsLLM(
             api_key=_get_next_api_key(),
-            base_url=(
-                HUGGINGFACE_BASE_URL_COMPLETION
-                if is_completion
-                else HUGGINGFACE_BASE_URL
-            ),
+            base_url=base_urls["huggingface"],
             tokenizer_id=tokenizer_id,
             use_magpie_template=use_magpie_template,
             structured_output=structured_output,
             **kwargs,
         )
-    elif VLLM_BASE_URL:
+    elif base_urls["vllm"]:
         if "generation_kwargs" in kwargs:
             if "do_sample" in kwargs["generation_kwargs"]:
                 del kwargs["generation_kwargs"]["do_sample"]
         llm = ClientvLLM(
-            base_url=VLLM_BASE_URL_COMPLETION if is_completion else VLLM_BASE_URL,
+            base_url=base_urls["vllm"],
             model=model,
             tokenizer=tokenizer_id,
             api_key=_get_next_api_key(),
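
For context, here is a minimal, self-contained sketch of the pattern this commit introduces (the helper name and the reduced set of providers are illustrative, not the repository's code): every provider's base URL is resolved once up front, the *_COMPLETION variant is preferred when is_completion is set, and the first provider with a configured URL wins, mirroring the if/elif chain above.

# Minimal sketch, assuming env vars named like those read in base.py.
# resolve_base_urls() is a hypothetical helper, not part of the repo.
import os

OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
OPENAI_BASE_URL_COMPLETION = os.getenv("OPENAI_BASE_URL_COMPLETION")
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL")
OLLAMA_BASE_URL_COMPLETION = os.getenv("OLLAMA_BASE_URL_COMPLETION")


def resolve_base_urls(is_completion: bool) -> dict:
    """Pick each provider's completion-specific URL when requested."""
    return {
        "openai": OPENAI_BASE_URL_COMPLETION if is_completion else OPENAI_BASE_URL,
        "ollama": OLLAMA_BASE_URL_COMPLETION if is_completion else OLLAMA_BASE_URL,
    }


base_urls = resolve_base_urls(is_completion=True)
# First configured provider wins, as in the diff's if/elif chain.
if base_urls["openai"]:
    print(f"OpenAILLM would use base_url={base_urls['openai']}")
elif base_urls["ollama"]:
    print(f"OllamaLLM would use host={base_urls['ollama']}")
else:
    print("no provider configured")

The point of the change, as the diff suggests, is that each branch condition now tests the same resolved URL that is passed to the client, so a provider configured only through its *_COMPLETION variable is still selected instead of being skipped.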