Fix BASE_URL_COMPLETIONS when using different providers
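Before this change, _get_llm checked each provider's default *_BASE_URL constant directly, and only the Hugging Face branch switched to HUGGINGFACE_BASE_URL_COMPLETION for completion models; the other providers kept their default endpoints even when a completion-specific URL was configured. The fix resolves all four base URLs once, up front, in a base_urls mapping keyed by provider, and every branch reads from that mapping.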
src/synthetic_dataset_generator/pipelines/base.py
@@ -87,10 +87,17 @@ def _get_llm(
 ):
     model = MODEL_COMPLETION if is_completion else MODEL
     tokenizer_id = TOKENIZER_ID_COMPLETION if is_completion else TOKENIZER_ID or model
-    if OPENAI_BASE_URL:
+    base_urls = {
+        "openai": OPENAI_BASE_URL_COMPLETION if is_completion else OPENAI_BASE_URL,
+        "ollama": OLLAMA_BASE_URL_COMPLETION if is_completion else OLLAMA_BASE_URL,
+        "huggingface": HUGGINGFACE_BASE_URL_COMPLETION if is_completion else HUGGINGFACE_BASE_URL,
+        "vllm": VLLM_BASE_URL_COMPLETION if is_completion else VLLM_BASE_URL,
+    }
+
+    if base_urls["openai"]:
         llm = OpenAILLM(
             model=model,
-            base_url=OPENAI_BASE_URL,
+            base_url=base_urls["openai"],
             api_key=_get_next_api_key(),
             structured_output=structured_output,
             **kwargs,
@@ -103,7 +110,7 @@ def _get_llm(
                 del kwargs["generation_kwargs"]["stop_sequences"]
             if "do_sample" in kwargs["generation_kwargs"]:
                 del kwargs["generation_kwargs"]["do_sample"]
-    elif OLLAMA_BASE_URL:
+    elif base_urls["ollama"]:
         if "generation_kwargs" in kwargs:
             if "max_new_tokens" in kwargs["generation_kwargs"]:
                 kwargs["generation_kwargs"]["num_predict"] = kwargs[
@@ -123,32 +130,28 @@ def _get_llm(
             kwargs["generation_kwargs"]["options"] = options
         llm = OllamaLLM(
             model=model,
-            host=OLLAMA_BASE_URL,
+            host=base_urls["ollama"],
             tokenizer_id=tokenizer_id,
             use_magpie_template=use_magpie_template,
             structured_output=structured_output,
             **kwargs,
         )
-    elif HUGGINGFACE_BASE_URL:
+    elif base_urls["huggingface"]:
         kwargs["generation_kwargs"]["do_sample"] = True
         llm = InferenceEndpointsLLM(
             api_key=_get_next_api_key(),
-            base_url=(
-                HUGGINGFACE_BASE_URL_COMPLETION
-                if is_completion
-                else HUGGINGFACE_BASE_URL
-            ),
+            base_url=base_urls["huggingface"],
             tokenizer_id=tokenizer_id,
             use_magpie_template=use_magpie_template,
             structured_output=structured_output,
             **kwargs,
         )
-    elif VLLM_BASE_URL:
+    elif base_urls["vllm"]:
         if "generation_kwargs" in kwargs:
             if "do_sample" in kwargs["generation_kwargs"]:
                 del kwargs["generation_kwargs"]["do_sample"]
         llm = ClientvLLM(
-            base_url=VLLM_BASE_URL,
+            base_url=base_urls["vllm"],
             model=model,
             tokenizer=tokenizer_id,
             api_key=_get_next_api_key(),
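For readers who want to try the selection logic outside the pipeline, below is a minimal, self-contained sketch of the same pattern. `resolve_base_urls` is a hypothetical helper invented for this example, and reading the constants from environment variables of the same names is an assumption about how the package is configured, not something shown in this diff.

```python
import os

# Hypothetical stand-ins for the module-level constants in base.py.
# Assumption: each is read from an environment variable of the same name.
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
OPENAI_BASE_URL_COMPLETION = os.getenv("OPENAI_BASE_URL_COMPLETION")
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL")
OLLAMA_BASE_URL_COMPLETION = os.getenv("OLLAMA_BASE_URL_COMPLETION")
HUGGINGFACE_BASE_URL = os.getenv("HUGGINGFACE_BASE_URL")
HUGGINGFACE_BASE_URL_COMPLETION = os.getenv("HUGGINGFACE_BASE_URL_COMPLETION")
VLLM_BASE_URL = os.getenv("VLLM_BASE_URL")
VLLM_BASE_URL_COMPLETION = os.getenv("VLLM_BASE_URL_COMPLETION")


def resolve_base_urls(is_completion: bool = False) -> dict[str, str | None]:
    # Same mapping the patch introduces: the is_completion switch is
    # applied once, for every provider, instead of only for Hugging Face.
    return {
        "openai": OPENAI_BASE_URL_COMPLETION if is_completion else OPENAI_BASE_URL,
        "ollama": OLLAMA_BASE_URL_COMPLETION if is_completion else OLLAMA_BASE_URL,
        "huggingface": HUGGINGFACE_BASE_URL_COMPLETION if is_completion else HUGGINGFACE_BASE_URL,
        "vllm": VLLM_BASE_URL_COMPLETION if is_completion else VLLM_BASE_URL,
    }


if __name__ == "__main__":
    # Mirrors the if/elif chain in _get_llm: the first provider with a
    # configured URL wins.
    urls = resolve_base_urls(is_completion=True)
    provider = next((name for name, url in urls.items() if url), None)
    print(provider, urls.get(provider))
```

The point of the dict is that the `is_completion` switch happens exactly once; each provider branch then sees either its default or its completion endpoint, so a configured `*_BASE_URL_COMPLETION` can no longer be silently ignored.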