| import torch |
|
|
| import numpy as np |
| import torch.nn.functional as F |
|
|
| from lm_eval.base import BaseLM |
| from datasets import load_dataset |
|
|
|
|
def set_seed(seed):
    """Seed the global NumPy and PyTorch random number generators.

    Args:
        seed: Value handed to ``np.random.seed`` and then to
            ``torch.random.manual_seed``.
    """
    # NumPy first, then PyTorch -- same order as before.
    for seeder in (np.random.seed, torch.random.manual_seed):
        seeder(seed)
|
|
def get_test_dataset(dataset_name, tokenizer, seqlen=2048):
    """Load a test corpus and pack it into token documents of at most ``seqlen``.

    Each non-empty text item is tokenized without special tokens and has
    ``tokenizer.eos_token_id`` appended.  Items are then concatenated greedily
    into documents that each start with ``tokenizer.bos_token_id``; a new
    document is started whenever the next item would not fit.

    Args:
        dataset_name: ``"wikitext2"`` or ``"c4"``.
        tokenizer: Callable returning a mapping with an ``'input_ids'`` list
            when called as ``tokenizer(text, add_special_tokens=False)``, and
            exposing ``bos_token_id`` and ``eos_token_id``.
        seqlen: Maximum number of tokens per packed document, BOS included.

    Returns:
        A list of token-id lists, each of length ``<= seqlen`` and each
        containing at least one item besides the leading BOS token.

    Raises:
        NotImplementedError: If ``dataset_name`` is not a supported dataset.
    """
    if dataset_name == "wikitext2":
        testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
        # Join all rows and re-split on newlines so each line is its own item.
        testdata = "".join(testdata['text']).split('\n')
    elif dataset_name == "c4":
        testdata = load_dataset(
            'allenai/c4',
            data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
            split='validation',
        )['text']
    else:
        raise NotImplementedError(f"Unsupported dataset: {dataset_name!r}")

    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id
    tokenized_text = [
        tokenizer(item, add_special_tokens=False)['input_ids'] + [eos_id]
        for item in testdata
        if item != ""
    ]

    data, doc = [], [bos_id]
    for sen in tokenized_text:
        # Every document spends one slot on BOS, so an item can use at most
        # ``seqlen - 1`` tokens.  Comparing against ``seqlen`` alone let an
        # item of exactly ``seqlen`` tokens through, which flushed a bare
        # ``[bos]`` document and then built one of ``seqlen + 1`` tokens.
        if len(sen) + 1 > seqlen:
            continue
        if len(doc) + len(sen) > seqlen:
            data.append(doc)
            doc = [bos_id]
        doc.extend(sen)

    # Flush the trailing document unless it holds nothing but the BOS token.
    if len(doc) > 1:
        data.append(doc)
    return data
|
|
|
|
class LMEvalAdaptor(BaseLM):
    """Expose a model/tokenizer pair through the ``BaseLM`` interface of ``lm_eval``."""

    def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1):
        """Store the model and tokenizer and put the model in eval mode.

        Args:
            model_name: Name string; consulted by ``max_length`` as a fallback
                when the model config carries no context-length attribute.
            model: Object providing ``eval()``, ``config``, ``generate()`` and
                a call returning a sequence whose first element is the logits.
            tokenizer: Object providing ``vocab_size``, ``eos_token_id``,
                ``encode`` and ``decode``.
            batch_size: Batch size reported by the ``batch_size`` property;
                must be an ``int``.
            max_length: Explicit maximum sequence length, or ``-1`` to derive
                it from the model config / model name.
        """
        super().__init__()

        assert isinstance(batch_size, int)

        self.model_name = model_name
        self.model = model
        self.model.eval()

        self.tokenizer = tokenizer
        self.vocab_size = tokenizer.vocab_size

        self._batch_size = batch_size
        self._max_length = max_length

    @property
    def eot_token_id(self):
        """Return the tokenizer's end-of-sequence token id."""
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Return the maximum sequence length for this model.

        Resolution order: the explicit ``max_length`` given at construction
        (when not ``-1``), then the first of ``n_ctx``,
        ``max_position_embeddings`` or ``n_positions`` found on the model
        config, then 2048 for model names containing ``bloom``, ``llama``,
        ``mpt`` or ``falcon``.

        Raises:
            NotImplementedError: If none of the above applies; the model
                config is printed first.
        """
        explicit = self._max_length
        if explicit != -1:
            return explicit

        config = self.model.config
        for field in ("n_ctx", "max_position_embeddings", "n_positions"):
            if hasattr(config, field):
                return getattr(config, field)

        known_families = ("bloom", "llama", "mpt", "falcon")
        if any(family in self.model_name for family in known_families):
            return 2048

        print(config)
        raise NotImplementedError

    @property
    def max_gen_toks(self):
        """Return the fixed generation budget of 256 tokens."""
        return 256

    @property
    def batch_size(self):
        """Return the batch size supplied at construction."""
        return self._batch_size

    @property
    def device(self):
        """Return the device string; always ``"cuda"``."""
        return "cuda"

    def tok_encode(self, string: str, add_special_tokens=True):
        """Encode ``string`` with the tokenizer, forwarding ``add_special_tokens``."""
        return self.tokenizer.encode(string, add_special_tokens=add_special_tokens)

    def tok_decode(self, tokens):
        """Decode ``tokens`` back to text with the tokenizer."""
        return self.tokenizer.decode(tokens)

    def loglikelihood(self, requests):
        """Tokenize ``(context, continuation)`` pairs and score them.

        Both strings are stripped of surrounding whitespace before encoding.
        An empty context is replaced by a single end-of-text token; otherwise
        the context is encoded with special tokens and the continuation
        without.

        Returns:
            Whatever ``self._loglikelihood_tokens`` returns for the encoded
            requests (not defined in this class; presumably inherited from
            ``BaseLM``).
        """
        encoded_requests = []
        for raw_context, raw_continuation in requests:
            context = raw_context.strip()
            continuation = raw_continuation.strip()

            context_enc = (
                self.tok_encode(context, add_special_tokens=True)
                if context != ""
                else [self.eot_token_id]
            )
            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)

            encoded_requests.append(
                ((context, continuation), context_enc, continuation_enc)
            )

        return self._loglikelihood_tokens(encoded_requests)

    def _model_call(self, inps):
        """Run the model on ``inps`` with gradients disabled.

        inps: a torch tensor of shape [batch, sequence]; the sequence length
        may differ between calls.

        returns: the first element of the model output, expected to be the
        logits tensor of shape [batch, sequence, vocab].
        """
        with torch.no_grad():
            return self.model(inps)[0]

    def _model_generate(self, context, max_length, eos_token_id):
        """Generate from ``context`` via ``self.model.generate`` with sampling off."""
        generation_kwargs = {
            "max_length": max_length,
            "eos_token_id": eos_token_id,
            "do_sample": False,
        }
        return self.model.generate(context, **generation_kwargs)