import os
from functools import cached_property
from typing import Any, Dict, List, Optional, Tuple, Union

from lm_eval.api.registry import register_model
from lm_eval.models.api_models import TemplateAPI
from lm_eval.utils import eval_logger


@register_model("local-completions")
class LocalCompletionsAPI(TemplateAPI):
    def __init__(
        self,
        base_url=None,
        tokenizer_backend="huggingface",
        **kwargs,
    ):
        super().__init__(
            base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
        )

    def _create_payload(
        self,
        messages: Union[List[List[int]], List[dict], List[str], str],
        generate=False,
        gen_kwargs: Optional[dict] = None,
        seed: int = 1234,
        **kwargs,
    ) -> dict:
        if generate:
            gen_kwargs = gen_kwargs or {}
            # `do_sample` is an HF-style flag with no equivalent in the completions API.
            gen_kwargs.pop("do_sample", False)
            max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
            temperature = gen_kwargs.pop("temperature", 0)
            stop = gen_kwargs.pop("until", ["<|endoftext|>"])
            return {
                "prompt": messages,
                "model": self.model,
                "max_tokens": max_tokens,
                "temperature": temperature,
                "stop": stop,
                "seed": seed,
                **gen_kwargs,
            }
        else:
            # Loglikelihood request: echo the prompt back with per-token logprobs
            # and generate at most a single extra token.
            return {
                "model": self.model,
                "prompt": messages,
                "temperature": 0,
                "max_tokens": 1,
                "logprobs": 1,
                "seed": seed,
                "echo": True,
            }

    @staticmethod
    def parse_logprobs(
        outputs: Union[Dict, List[Dict]],
        tokens: Optional[List[List[int]]] = None,
        ctxlens: Optional[List[int]] = None,
        **kwargs,
    ) -> List[Tuple[float, bool]]:
        res = []
        if not isinstance(outputs, list):
            outputs = [outputs]
        for out in outputs:
            for choice, ctxlen in zip(out["choices"], ctxlens):
                assert ctxlen > 0, "Context length must be greater than 0"
                # Sum the logprobs of the continuation tokens (everything after the
                # context, excluding the single generated token at the end).
                logprobs = sum(choice["logprobs"]["token_logprobs"][ctxlen:-1])
                output_tokens = choice["logprobs"]["tokens"][ctxlen:-1]
                top_logprobs = choice["logprobs"]["top_logprobs"][ctxlen:-1]
                # The continuation is greedy iff every continuation token is the
                # argmax of its top_logprobs distribution.
                is_greedy = True
                for tok, top in zip(output_tokens, top_logprobs):
                    if tok != max(top, key=top.get):
                        is_greedy = False
                        break
                res.append((logprobs, is_greedy))
        return res

    @staticmethod
    def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
        res = []
        if not isinstance(outputs, list):
            outputs = [outputs]
        for out in outputs:
            for choice in out["choices"]:
                res.append(choice["text"])
        return res

    @property
    def api_key(self):
        # Local servers typically require no authentication, so an empty key is acceptable.
        return os.environ.get("OPENAI_API_KEY", "")

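# Example usage (a sketch, not part of the module): `local-completions` targets any
# OpenAI-compatible /v1/completions server. The server URL, model name, and task
# below are illustrative assumptions, e.g. a vLLM instance served locally:
#
#   lm_eval --model local-completions \
#       --tasks lambada_openai \
#       --model_args model=meta-llama/Meta-Llama-3-8B,base_url=http://localhost:8000/v1/completions,num_concurrent=1
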
@register_model("local-chat-completions") |
|
class LocalChatCompletion(LocalCompletionsAPI): |
|
def __init__( |
|
self, |
|
base_url=None, |
|
tokenizer_backend=None, |
|
tokenized_requests=False, |
|
**kwargs, |
|
): |
|
eval_logger.warning( |
|
"chat-completions endpoint requires the `--apply_chat_template` flag." |
|
) |
|
super().__init__( |
|
base_url=base_url, |
|
tokenizer_backend=tokenizer_backend, |
|
tokenized_requests=tokenized_requests, |
|
**kwargs, |
|
) |
|
if self._batch_size > 1: |
|
eval_logger.warning( |
|
"Chat completions does not support batching. Defaulting to batch size 1." |
|
) |
|
self._batch_size = 1 |
|
|
|
def _create_payload( |
|
self, |
|
messages: List[Dict], |
|
generate=False, |
|
gen_kwargs: dict = None, |
|
seed=1234, |
|
**kwargs, |
|
) -> dict: |
|
gen_kwargs.pop("do_sample", False) |
|
max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks) |
|
temperature = gen_kwargs.pop("temperature", 0) |
|
stop = gen_kwargs.pop("until", ["<|endoftext|>"]) |
|
if not isinstance(stop, (list, tuple)): |
|
stop = [stop] |
|
return { |
|
"messages": messages, |
|
"model": self.model, |
|
"max_tokens": max_tokens, |
|
"temperature": temperature, |
|
"stop": stop[:4], |
|
"seed": seed, |
|
**gen_kwargs, |
|
} |
|
|
|
@staticmethod |
|
def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]: |
|
res = [] |
|
if not isinstance(outputs, list): |
|
outputs = [outputs] |
|
for out in outputs: |
|
for choices in out["choices"]: |
|
res.append(choices["message"]["content"]) |
|
return res |
|
|
|
def tok_encode( |
|
self, |
|
string: Union[str, Any], |
|
left_truncate_len=None, |
|
add_special_tokens=None, |
|
**kwargs, |
|
) -> Union[List[str], List[int], Any]: |
|
return string |
|
|
|
def loglikelihood(self, requests, **kwargs): |
|
raise NotImplementedError( |
|
"Loglikelihood is not supported for chat completions. Consider using the completions API instead." |
|
) |
|
|
|
|
|
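# Example usage (a sketch; the URL, model name, and task are placeholders): chat
# endpoints need prompts rendered as chat messages, hence --apply_chat_template:
#
#   lm_eval --model local-chat-completions \
#       --tasks gsm8k \
#       --apply_chat_template \
#       --model_args model=meta-llama/Meta-Llama-3-8B-Instruct,base_url=http://localhost:8000/v1/chat/completions
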
@register_model("openai-completions")
class OpenAICompletionsAPI(LocalCompletionsAPI):
    def __init__(
        self,
        base_url="https://api.openai.com/v1/completions",
        tokenizer_backend="tiktoken",
        **kwargs,
    ):
        super().__init__(
            base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
        )

    @cached_property
    def api_key(self):
        """Override this property to return the API key for the API request."""
        key = os.environ.get("OPENAI_API_KEY", None)
        if key is None:
            raise ValueError(
                "API key not found. Please set the OPENAI_API_KEY environment variable."
            )
        return key

    def loglikelihood(self, requests, **kwargs):
        assert (
            self.model != "gpt-3.5-turbo"
        ), "Loglikelihood is not supported for gpt-3.5-turbo"
        return super().loglikelihood(requests, **kwargs)

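# Example usage (a sketch; model and task choices are illustrative only, and
# OPENAI_API_KEY must be set). The same backends can be driven from Python via
# lm_eval.simple_evaluate:
#
#   import lm_eval
#   results = lm_eval.simple_evaluate(
#       model="openai-completions",
#       model_args="model=davinci-002",
#       tasks=["lambada_openai"],
#   )
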
@register_model("openai-chat-completions") |
|
class OpenAIChatCompletion(LocalChatCompletion): |
|
def __init__( |
|
self, |
|
base_url="https://api.openai.com/v1/chat/completions", |
|
tokenizer_backend=None, |
|
tokenized_requests=False, |
|
**kwargs, |
|
): |
|
super().__init__( |
|
base_url=base_url, |
|
tokenizer_backend=tokenizer_backend, |
|
tokenized_requests=tokenized_requests, |
|
**kwargs, |
|
) |
|
|
|
@cached_property |
|
def api_key(self): |
|
"""Override this property to return the API key for the API request.""" |
|
key = os.environ.get("OPENAI_API_KEY", None) |
|
if key is None: |
|
raise ValueError( |
|
"API key not found. Please set the OPENAI_API_KEY environment variable." |
|
) |
|
return key |
|
|
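# Example usage (a sketch; the model name is a placeholder for any chat-capable
# OpenAI model, and OPENAI_API_KEY must be set in the environment):
#
#   lm_eval --model openai-chat-completions \
#       --tasks gsm8k \
#       --apply_chat_template \
#       --model_args model=gpt-4o-mini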