import math
import random
from distilabel.models import ClientvLLM, InferenceEndpointsLLM, OllamaLLM, OpenAILLM
from distilabel.steps.tasks import TextGeneration
from synthetic_dataset_generator.constants import (
API_KEYS,
DEFAULT_BATCH_SIZE,
HUGGINGFACE_BASE_URL,
HUGGINGFACE_BASE_URL_COMPLETION,
MODEL,
MODEL_COMPLETION,
OLLAMA_BASE_URL,
OLLAMA_BASE_URL_COMPLETION,
OPENAI_BASE_URL,
OPENAI_BASE_URL_COMPLETION,
TOKENIZER_ID,
TOKENIZER_ID_COMPLETION,
VLLM_BASE_URL,
VLLM_BASE_URL_COMPLETION,
)


TOKEN_INDEX = 0


def _get_next_api_key():
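    """Return the next API key, rotating round-robin over API_KEYS."""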
global TOKEN_INDEX
api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)]
TOKEN_INDEX += 1
return api_key


def _get_prompt_rewriter():
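    """Create and load a TextGeneration task that rewrites a prompt without adding new content."""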
generation_kwargs = {
"temperature": 1,
}
system_prompt = "You are a prompt rewriter. You are given a prompt and you need to rewrite it keeping the same structure but highlighting different aspects of the original without adding anything new."
prompt_rewriter = TextGeneration(
llm=_get_llm(generation_kwargs=generation_kwargs),
system_prompt=system_prompt,
use_system_prompt=True,
)
prompt_rewriter.load()
return prompt_rewriter


def get_rewritten_prompts(prompt: str, num_rows: int):
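    """Return the original prompt plus rewritten variants, roughly one per 100 requested rows.

    Rewrites are produced in batches of DEFAULT_BATCH_SIZE and add variety to the
    prompts used during dataset generation.
    """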
prompt_rewriter = _get_prompt_rewriter()
# create prompt rewrites
inputs = [
{"instruction": f"Original prompt: {prompt} \nRewritten prompt: "}
        for _ in range(math.floor(num_rows / 100))
]
n_processed = 0
prompt_rewrites = [prompt]
while n_processed < num_rows:
batch = list(
prompt_rewriter.process(
inputs=inputs[n_processed : n_processed + DEFAULT_BATCH_SIZE]
)
)
prompt_rewrites += [entry["generation"] for entry in batch[0]]
n_processed += DEFAULT_BATCH_SIZE
random.seed(a=random.randint(0, 2**32 - 1))
return prompt_rewrites


def _get_llm_class() -> str:
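    """Return the name of the distilabel LLM class that matches the configured base URL."""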
if OPENAI_BASE_URL:
return "OpenAILLM"
elif OLLAMA_BASE_URL:
return "OllamaLLM"
elif HUGGINGFACE_BASE_URL:
return "InferenceEndpointsLLM"
elif VLLM_BASE_URL:
return "ClientvLLM"
else:
return "InferenceEndpointsLLM"


def _get_llm(
    structured_output: dict = None,
    use_magpie_template: bool = False,
    is_completion: bool = False,
    **kwargs,
):
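    """Build the distilabel LLM client for the configured provider.

    The provider is chosen from the configured base URLs, generation kwargs are
    translated to the names each backend expects, and API keys are rotated via
    _get_next_api_key().
    """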
model = MODEL_COMPLETION if is_completion else MODEL
    # Fall back to the model id when no tokenizer is explicitly configured.
    tokenizer_id = (TOKENIZER_ID_COMPLETION if is_completion else TOKENIZER_ID) or model
base_urls = {
"openai": OPENAI_BASE_URL_COMPLETION if is_completion else OPENAI_BASE_URL,
"ollama": OLLAMA_BASE_URL_COMPLETION if is_completion else OLLAMA_BASE_URL,
"huggingface": HUGGINGFACE_BASE_URL_COMPLETION if is_completion else HUGGINGFACE_BASE_URL,
"vllm": VLLM_BASE_URL_COMPLETION if is_completion else VLLM_BASE_URL,
}
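    # Provider selection order: OpenAI-compatible URL first, then Ollama, then a
    # Hugging Face endpoint, then vLLM; with none configured, fall back to
    # serverless Hugging Face inference.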
    if base_urls["openai"]:
        # OpenAI-compatible endpoints expect "stop" rather than "stop_sequences"
        # and do not accept "do_sample", so translate the kwargs before
        # constructing the client.
        if "generation_kwargs" in kwargs:
            if "stop_sequences" in kwargs["generation_kwargs"]:
                kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
                    "stop_sequences"
                ]
                del kwargs["generation_kwargs"]["stop_sequences"]
            if "do_sample" in kwargs["generation_kwargs"]:
                del kwargs["generation_kwargs"]["do_sample"]
        llm = OpenAILLM(
            model=model,
            base_url=base_urls["openai"],
            api_key=_get_next_api_key(),
            structured_output=structured_output,
            **kwargs,
        )
elif base_urls["ollama"]:
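        # Ollama expects generation options nested under "options" and uses
        # "num_predict"/"stop" rather than "max_new_tokens"/"stop_sequences".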
if "generation_kwargs" in kwargs:
if "max_new_tokens" in kwargs["generation_kwargs"]:
kwargs["generation_kwargs"]["num_predict"] = kwargs[
"generation_kwargs"
]["max_new_tokens"]
del kwargs["generation_kwargs"]["max_new_tokens"]
if "stop_sequences" in kwargs["generation_kwargs"]:
kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
"stop_sequences"
]
del kwargs["generation_kwargs"]["stop_sequences"]
if "do_sample" in kwargs["generation_kwargs"]:
del kwargs["generation_kwargs"]["do_sample"]
options = kwargs["generation_kwargs"]
del kwargs["generation_kwargs"]
kwargs["generation_kwargs"] = {}
kwargs["generation_kwargs"]["options"] = options
llm = OllamaLLM(
model=model,
host=base_urls["ollama"],
tokenizer_id=tokenizer_id,
use_magpie_template=use_magpie_template,
structured_output=structured_output,
**kwargs,
)
elif base_urls["huggingface"]:
kwargs["generation_kwargs"]["do_sample"] = True
llm = InferenceEndpointsLLM(
api_key=_get_next_api_key(),
base_url=base_urls["huggingface"],
tokenizer_id=tokenizer_id,
use_magpie_template=use_magpie_template,
structured_output=structured_output,
**kwargs,
)
elif base_urls["vllm"]:
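        # Strip "do_sample", which the OpenAI-style API used by ClientvLLM does not accept.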
if "generation_kwargs" in kwargs:
if "do_sample" in kwargs["generation_kwargs"]:
del kwargs["generation_kwargs"]["do_sample"]
llm = ClientvLLM(
base_url=base_urls["vllm"],
model=model,
tokenizer=tokenizer_id,
api_key=_get_next_api_key(),
use_magpie_template=use_magpie_template,
structured_output=structured_output,
**kwargs,
)
else:
llm = InferenceEndpointsLLM(
api_key=_get_next_api_key(),
tokenizer_id=tokenizer_id,
model_id=model,
use_magpie_template=use_magpie_template,
structured_output=structured_output,
**kwargs,
)
return llm
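

# Smoke-test the configured LLM at import time so a misconfigured provider or
# credential fails immediately instead of during dataset generation.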
try:
    llm = _get_llm()
    llm.load()
    llm.generate([[{"content": "Hello, world!", "role": "user"}]])
except Exception as e:
    raise Exception(f"Error loading {_get_llm_class()}: {e}") from e