import torch
import torch.nn.functional as F
import numpy as np

from lm_eval.base import BaseLM
from datasets import load_dataset


def set_seed(seed):
    """Seed the NumPy and PyTorch RNGs for reproducible runs."""
    np.random.seed(seed)
    torch.random.manual_seed(seed)


def get_test_dataset(dataset_name, tokenizer, seqlen=2048):
    """Load a test split and pack it into token chunks of at most `seqlen`."""
    if dataset_name == "wikitext2":
        testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
        testdata = "".join(testdata['text']).split('\n')
    elif dataset_name == "c4":
        testdata = load_dataset('allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')['text']
    else:
        raise NotImplementedError(f"Unsupported dataset: {dataset_name}")

    # Drop empty documents, then tokenize each one and terminate it with EOS.
    testdata = [item for item in testdata if item != ""]
    tokenized_text = [tokenizer(item, add_special_tokens=False)['input_ids'] + [tokenizer.eos_token_id] for item in testdata]

    # Greedily pack documents into chunks that start with BOS and never
    # exceed `seqlen` tokens.
    data, doc = [], [tokenizer.bos_token_id]
    for sen in tokenized_text:
        if len(sen) > seqlen - 1:  # reserve one slot for the BOS token
            continue
        if len(doc) + len(sen) > seqlen:
            data.append(doc)
            doc = [tokenizer.bos_token_id]
        doc.extend(sen)
    if len(doc) > 1 and len(doc) <= seqlen:
        data.append(doc)
    return data
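
# A minimal usage sketch (hypothetical checkpoint name; assumes a Hugging Face
# tokenizer whose `bos_token_id` and `eos_token_id` are both set):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
#   chunks = get_test_dataset("wikitext2", tokenizer, seqlen=2048)
#   # each chunk: [BOS, doc tokens..., EOS, doc tokens..., EOS, ...]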


class LMEvalAdaptor(BaseLM):
    """Adapt a Hugging Face causal LM to the lm-eval-harness BaseLM interface."""

    def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1):
        super().__init__()

        assert isinstance(batch_size, int)

        self.model_name = model_name
        self.model = model
        self.model.eval()

        self.tokenizer = tokenizer
        self.vocab_size = self.tokenizer.vocab_size

        self._batch_size = batch_size
        self._max_length = max_length  # -1 means: infer from the model config

    @property
    def eot_token_id(self):
        # The harness uses the end-of-text token; reuse the tokenizer's EOS.
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        if self._max_length != -1:  # explicit user override
            return self._max_length
        # Otherwise infer the context window from the model config, falling
        # back to 2048 for known model families that do not expose it there.
        if hasattr(self.model.config, "n_ctx"):
            return self.model.config.n_ctx
        elif hasattr(self.model.config, "max_position_embeddings"):
            return self.model.config.max_position_embeddings
        elif hasattr(self.model.config, "n_positions"):
            return self.model.config.n_positions
        elif any(name in self.model_name for name in ("bloom", "llama", "mpt", "falcon")):
            return 2048
        else:
            print(self.model.config)
            raise NotImplementedError("Cannot infer max_length from the model config")

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def device(self):
        # Assumes the model has been moved to a CUDA device.
        return "cuda"

    def tok_encode(self, string: str, add_special_tokens=True):
        return self.tokenizer.encode(string, add_special_tokens=add_special_tokens)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in requests:
            # This adaptor strips surrounding whitespace before encoding.
            context, continuation = context.strip(), continuation.strip()
            if context == "":
                # The end-of-text token stands in for an empty context
                # (unconditional scoring).
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context, add_special_tokens=True)

            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)
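
    # Each incoming request is a (context, continuation) string pair, e.g.
    # ("Question: What is the capital of France?\nAnswer:", " Paris"); the
    # harness scores the continuation tokens conditioned on the context.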
    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence];
        the size of sequence may vary from call to call.

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model.
        """
        with torch.no_grad():
            out = self.model(inps)[0]
        return out
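
    # Downstream, BaseLM._loglikelihood_tokens turns these logits into
    # per-token log-probabilities; a minimal sketch of that step, using the
    # `F` alias imported above:
    #
    #   logits = self._model_call(inps)            # [batch, seq, vocab]
    #   log_probs = F.log_softmax(logits, dim=-1)  # per-token log-probs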
    def _model_generate(self, context, max_length, eos_token_id):
        # Greedy decoding (do_sample=False), stopping at EOS or max_length.
        return self.model.generate(
            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
        )
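
# A minimal end-to-end sketch (hypothetical checkpoint name; assumes the older
# lm-eval-harness API that provides `lm_eval.base.BaseLM`):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   from lm_eval import evaluator
#
#   model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b").cuda()
#   tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
#   lm = LMEvalAdaptor("huggyllama/llama-7b", model, tokenizer, batch_size=1)
#   results = evaluator.simple_evaluate(model=lm, tasks=["hellaswag"],
#                                       no_cache=True)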