File size: 2,256 Bytes
5532825 1fc08db 5532825 1fc08db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import warnings
import distilabel
import distilabel.distiset
from distilabel.models import InferenceEndpointsLLM
from pydantic import (
ValidationError,
model_validator,
)
class CustomInferenceEndpointsLLM(InferenceEndpointsLLM):
    """`InferenceEndpointsLLM` subclass that defines its own "after" model validator.

    The validator checks the combination of `model_id`, `endpoint_name` and
    `base_url`, enforces the `use_magpie_template` / `tokenizer_id` requirement and
    fills in `tokenizer_id` from `model_id` when structured output is requested.

    NOTE(review): the validator name presumably matches (and therefore replaces) a
    validator of the same name on the parent class -- confirm against distilabel.
    """

    @model_validator(mode="after")  # type: ignore
    def only_one_of_model_id_endpoint_name_or_base_url_provided(
        self,
    ) -> "InferenceEndpointsLLM":
        """Validate that only one of `model_id` or `endpoint_name` is provided.

        If `base_url` is provided together with `model_id` or `endpoint_name`, a
        warning is emitted informing the user that `base_url` will be ignored or
        overwritten in favour of the dynamically calculated one.

        Returns:
            The validated instance (`self`). `tokenizer_id` is set to `model_id`
            when `model_id` is given, `tokenizer_id` is `None` and
            `structured_output` is not `None`.

        Raises:
            ValueError: if `use_magpie_template` is truthy while `tokenizer_id` is
                `None`, or if the combination of `model_id`, `endpoint_name` and
                `base_url` is not one of the accepted ones (this includes the case
                where none of them is provided). Raised inside a pydantic
                validator, so pydantic reports it as a `ValidationError`.
        """
        if self.base_url and (self.model_id or self.endpoint_name):
            warnings.warn(  # type: ignore
                f"Since the `base_url={self.base_url}` is available and either one of `model_id`"
                " or `endpoint_name` is also provided, the `base_url` will either be ignored"
                " or overwritten with the one generated from either of those args, for serverless"
                " or dedicated inference endpoints, respectively."
            )

        if self.use_magpie_template and self.tokenizer_id is None:
            raise ValueError(
                "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
                " set a `tokenizer_id` and try again."
            )

        # Default the tokenizer to the model itself when structured output is
        # requested and no explicit tokenizer was configured.
        if (
            self.model_id
            and self.tokenizer_id is None
            and self.structured_output is not None
        ):
            self.tokenizer_id = self.model_id

        # Accepted combinations: `base_url` alone, `model_id` without
        # `endpoint_name`, or `endpoint_name` without `model_id`.
        if self.base_url and not (self.model_id or self.endpoint_name):
            return self

        if self.model_id and not self.endpoint_name:
            return self

        if self.endpoint_name and not self.model_id:
            return self

        # Raise `ValueError` rather than instantiating pydantic's `ValidationError`
        # with a message: validators are expected to raise `ValueError` /
        # `AssertionError`, and pydantic converts that into a `ValidationError`
        # carrying this message for the caller.
        raise ValueError(
            "Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is"
            f" provided too, it will be overwritten instead. Found `model_id`={self.model_id},"
            f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}."
        )
# Monkey-patch: rebind the `InferenceEndpointsLLM` attribute of `distilabel.models.llms`
# to `CustomInferenceEndpointsLLM`, so attribute lookups on that module made after this
# line resolve to the custom class.
# NOTE(review): names already bound elsewhere via `from ... import InferenceEndpointsLLM`
# are not changed by this assignment -- confirm it reaches every intended call site.
distilabel.models.llms.InferenceEndpointsLLM = CustomInferenceEndpointsLLM
|