'''
This file is inspired by the code from https://github.com/ML-GSAI/SMDM
'''
import accelerate
import torch
import re
from pathlib import Path
import random
import numpy as np
import torch.nn.functional as F
from datasets import Dataset
from lm_eval.__main__ import cli_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModel
from generate import generate


def set_seed(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


@register_model("llada_dist")
class LLaDAEvalHarness(LM):
    def __init__(
        self,
        model_path='',
        mask_id=126336,
        max_length=4096,
        batch_size=32,
        mc_num=128,
        is_check_greedy=True,
        cfg=0.,
        steps=1024,
        gen_length=1024,
        block_length=1024,
        remasking='low_confidence',
        device="cuda",
        **kwargs,
    ):
        '''
        Args:
            model_path: LLaDA-8B-Base model path.
            mask_id: The token id of [MASK] is 126336.
            max_length: The maximum sequence length.
            batch_size: Mini-batch size.
            mc_num: The number of Monte Carlo estimation iterations.
            is_check_greedy: For certain metrics like LAMBADA, the evaluation requires the model to verify whether the answer
                is generated through greedy sampling conditioned on the prompt (note that this differs from conditional
                generation). We implement this verification through the suffix_greedy_prediction() function, which
                returns a True/False judgment used for accuracy calculation.
                When is_check_greedy is set to True, the lm-evaluation-harness library automatically invokes this function.
                However, since none of the metrics in the LLaDA paper (https://arxiv.org/abs/2502.09992) require this functionality,
                we recommend setting is_check_greedy to False. This configuration causes suffix_greedy_prediction() to return False
                by default, significantly accelerating the evaluation process.
            cfg: Unsupervised classifier-free guidance scale.
            steps, gen_length, block_length, remasking: Sampling hyperparameters forwarded to generate() for generation tasks.
        '''
        super().__init__()

        accelerator = accelerate.Accelerator()
        if accelerator.num_processes > 1:
            self.accelerator = accelerator
        else:
            self.accelerator = None

        model_kwargs = {}
        if self.accelerator is not None:
            model_kwargs.update({'device_map': {'': f'{self.accelerator.device}'}})

        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, **model_kwargs)
        self.model.eval()

        self.device = torch.device(device)
        if self.accelerator is not None:
            self.model = self.accelerator.prepare(self.model)
            self.device = torch.device(f'{self.accelerator.device}')
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        else:
            self.model = self.model.to(device)

        self.mask_id = mask_id
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        self.mc_num = mc_num
        self.batch_size = int(batch_size)
        assert mc_num % self.batch_size == 0
        self.sampling_eps = 0.
        self.max_length = max_length
        self.is_check_greedy = is_check_greedy

        self.cfg = cfg
        self.steps = steps
        self.gen_length = gen_length
        self.block_length = block_length
        self.remasking = remasking

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

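    # Descriptive note: _forward_process builds the noisy (partially masked) copies of the batch
    # used by the Monte Carlo likelihood estimator. Each of the batch_size copies masks a different
    # number of answer tokens (prompt tokens are never masked), and the second return value is the
    # per-token masking ratio p_mask, which is later used to importance-weight the masked loss.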
    def _forward_process(self, batch, prompt_index):
        b, l = batch.shape

        target_len = (l - prompt_index.sum()).item()
        k = torch.randint(1, target_len + 1, (), device=batch.device)

        x = torch.round(torch.linspace(float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device)).long()
        x = ((x - 1) % target_len) + 1
        assert x.min() >= 1 and x.max() <= target_len

        indices = torch.arange(target_len, device=batch.device).repeat(b, 1)
        is_mask = indices < x.unsqueeze(1)

        for i in range(b):
            is_mask[i] = is_mask[i][torch.randperm(target_len)]

        is_mask = torch.cat((torch.zeros(b, prompt_index.sum(), dtype=torch.bool, device=batch.device), is_mask), dim=1)

        noisy_batch = torch.where(is_mask, self.mask_id, batch)

        return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l)

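    # Descriptive note: when self.cfg > 0, get_logits applies unsupervised classifier-free guidance.
    # An "unconditional" copy of the batch is built by masking the prompt tokens, both copies are run
    # through the model in one forward pass, and the logits are combined as
    #   logits = un_logits + (cfg + 1) * (cond_logits - un_logits)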
    @torch.no_grad()
    def get_logits(self, batch, prompt_index):
        if self.cfg > 0.:
            assert len(prompt_index) == batch.shape[1]
            prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
            un_batch = batch.clone()
            un_batch[prompt_index] = self.mask_id
            batch = torch.cat([batch, un_batch])

        logits = self.model(batch).logits

        if self.cfg > 0.:
            logits, un_logits = torch.chunk(logits, 2, dim=0)
            logits = un_logits + (self.cfg + 1) * (logits - un_logits)
        return logits[:, :batch.shape[1]]

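    # Descriptive note: get_loglikelihood estimates the conditional log-likelihood of `target` given
    # `prefix` by Monte Carlo. The target span is repeatedly re-masked via _forward_process, the
    # cross-entropy on the masked positions is divided by p_mask (importance weighting), and the
    # negated average over mc_num // batch_size mini-batches is returned.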
    @torch.no_grad()
    def get_loglikelihood(self, prefix, target):
        seq = torch.concatenate([prefix, target])[None, :]
        seq = seq.repeat((self.batch_size, 1)).to(self.device)

        prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)

        loss_acc = []
        for _ in range(self.mc_num // self.batch_size):
            perturbed_seq, p_mask = self._forward_process(seq, prompt_index)

            mask_indices = perturbed_seq == self.mask_id

            logits = self.get_logits(perturbed_seq, prompt_index)

            loss = F.cross_entropy(logits[mask_indices], seq[mask_indices], reduction='none') / p_mask[mask_indices]
            loss = loss.sum() / self.batch_size
            loss_acc.append(loss.item())

        return - sum(loss_acc) / len(loss_acc)

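    # Descriptive note: suffix_greedy_prediction is the greedy-decoding check used by accuracy-style
    # metrics (e.g. LAMBADA). Starting from a fully masked target, one token is unmasked per step,
    # keeping only the highest-confidence prediction each time, and the result is compared with the
    # reference target. When is_check_greedy is False this check is skipped and False is returned.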
    @torch.no_grad()
    def suffix_greedy_prediction(self, prefix, target):
        if not self.is_check_greedy:
            return False

        seq = torch.full((1, len(prefix) + len(target)), self.mask_id, device=self.device)
        prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
        prefix, target = prefix.to(self.device), target.to(self.device)
        seq[0, :len(prefix)] = prefix

        for i in range(len(target)):
            mask_index = (seq == self.mask_id)
            logits = self.get_logits(seq, prompt_index)[mask_index]
            x0 = torch.argmax(logits, dim=-1)

            p = torch.softmax(logits.to(torch.float32), dim=-1)
            confidence = torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)).squeeze(dim=-1)
            _, index = torch.sort(confidence, descending=True)
            x0[index[1:]] = self.mask_id
            seq[mask_index] = x0.clone()
        correct = target == seq[0, len(prefix):]
        correct = torch.all(correct)
        return correct

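    # Descriptive note: _encode_pair tokenizes a (context, continuation) pair. Trailing whitespace is
    # moved from the context onto the continuation before tokenization, so the continuation token ids
    # match how they would be tokenized inside the full concatenated string.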
    def _encode_pair(self, context, continuation):
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]

        whole_enc = self.tokenizer(context + continuation)["input_ids"]
        context_enc = self.tokenizer(context)["input_ids"]

        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]

        return context_enc, continuation_enc

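    # Descriptive note: lm-evaluation-harness entry point for loglikelihood requests. Each request
    # carries args = (context, continuation); the method returns a list of (log-likelihood, is_greedy)
    # pairs, where is_greedy comes from suffix_greedy_prediction().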
    def loglikelihood(self, requests):
        def _tokenize(e):
            prefix, target = self._encode_pair(e["prefix"], e["target"])
            return {
                "prefix_text": e["prefix"],
                "target_text": e["target"],
                "prefix": prefix,
                "target": target,
            }

        ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
        ds = Dataset.from_list(ds)
        ds = ds.map(_tokenize)
        ds = ds.with_format("torch")
        prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds]

        assert max(prompt_len) <= self.max_length

        out = []
        with torch.no_grad():
            for elem in tqdm(ds, desc="Computing likelihood..."):
                prefix = elem["prefix"]
                target = elem["target"]

                ll = self.get_loglikelihood(prefix, target)

                is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target)

                out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
                torch.cuda.empty_cache()
        return out

    def loglikelihood_rolling(self, requests):
        raise NotImplementedError

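    # Descriptive note: lm-evaluation-harness entry point for generation tasks. Each request carries
    # args = (prompt, gen_kwargs) where gen_kwargs["until"] lists stop sequences; the generated text
    # is truncated at the first stop sequence and returned as a plain string.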
    def generate_until(self, requests: list[Instance]):
        def _tokenize(e):
            return {
                "question": self.tokenizer(e["question"])["input_ids"],
                "question_text": e["question"],
                "until": e["until"],
            }

        ds = [{"question": req.args[0], "until": req.args[1]['until']} for req in requests]
        ds = Dataset.from_list(ds)
        ds = ds.map(_tokenize)
        ds = ds.with_format("torch")

        out = []
        for elem in tqdm(ds, desc="Generating..."):
            prompt = elem["question"].unsqueeze(0).to(self.device)
            stop_tokens = elem["until"]

            generated_answer = generate(self.model, prompt, steps=self.steps, gen_length=self.gen_length, block_length=self.block_length,
                                        temperature=0, cfg_scale=self.cfg, remasking=self.remasking, mask_id=self.mask_id)

            generated_answer = self.tokenizer.decode(generated_answer[0][prompt.shape[1]:], skip_special_tokens=False)
            for stop_seq in stop_tokens:
                if stop_seq in generated_answer:
                    generated_answer = generated_answer.split(stop_seq)[0]

            # re-encode and decode to strip special tokens from the final answer
            generated_answer_ids = self.tokenizer(generated_answer)["input_ids"]
            generated_answer = self.tokenizer.decode(generated_answer_ids, skip_special_tokens=True)
            out.append(generated_answer)

        if self.accelerator is not None:
            self.accelerator.wait_for_everyone()

        return out

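# A minimal sketch of how an evaluation might be launched with this harness; the script filename,
# task name, and model path below are illustrative assumptions, not fixed values. cli_evaluate()
# parses the standard lm-evaluation-harness command-line arguments, and the key=value pairs given
# in --model_args are forwarded to LLaDAEvalHarness.__init__ above.
#
#   accelerate launch eval_llada.py \
#       --model llada_dist \
#       --model_args model_path=GSAI-ML/LLaDA-8B-Base,cfg=0.0,is_check_greedy=False,mc_num=128 \
#       --tasks gsm8k --num_fewshot 4 --batch_size 8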
if __name__ == "__main__":
    set_seed(1234)
    cli_evaluate()