# kkawamu1
# Update run_eval.py
# 1f6c0e9 unverified
# Copyright 2022 Ken Kawamura
# Copyright BigScience, The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file has been modified from the original version at https://github.com/bigscience-workshop/t-zero/blob/master/evaluation/run_eval.py
import torch
from accelerate import Accelerator
from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoTokenizer, set_seed)
def _load_model_and_tokenizer(model_name_or_path, auto_class):
    """Load a pretrained model and its fast tokenizer, ensuring a pad token exists.

    ``auto_class == 'Seq2SeqLM'`` selects ``AutoModelForSeq2SeqLM``
    (e.g. 'google/t5-small-lm-adapt'); any other value selects
    ``AutoModelForCausalLM`` (e.g. 'gpt2').

    Returns:
        (model, tokenizer)

    Raises:
        ValueError: if the tokenizer has no pad token and none of its
            eos/bos/sep tokens is available to stand in for one.
    """
    if auto_class == 'Seq2SeqLM':
        model_cls = AutoModelForSeq2SeqLM
    else:
        model_cls = AutoModelForCausalLM
    model = model_cls.from_pretrained(model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

    if tokenizer.pad_token is None:
        # Answer choices are padded to a common length, so a pad token is
        # required. Padded positions are masked out of the score, so which
        # token is used does not affect the result. As in the original code,
        # the last non-None candidate is the one that ends up being used.
        for token in [tokenizer.eos_token, tokenizer.bos_token, tokenizer.sep_token]:
            if token is not None:
                tokenizer.pad_token = token
        if tokenizer.pad_token is None:
            raise ValueError("Please define a pad token id.")
    return model, tokenizer


def _tokenize_choices(tokenizer, context, ex_answer_choices, auto_class):
    """Build one row of model features per answer choice.

    Returns a dict with keys 'input_ids', 'attention_mask', 'labels' and
    'labels_attention_mask'. Each value is a list of rows (one per choice),
    and all rows under one key have the same length, so the dict can be
    turned directly into rectangular tensors.

    Raises:
        ValueError: in CausalLM mode, if ``context`` tokenizes to no tokens
            (there would be nothing to condition the first answer token on).
    """
    # Each answer is prefixed with a space, as in the original preprocessing.
    answer_texts = [' ' + ans for ans in ex_answer_choices]
    context_enc = tokenizer(
        context,
        padding=False,
        max_length=1024,
        truncation=True,
        add_special_tokens=False,
    )
    # padding=True pads every answer to the longest one in this list.
    target_enc = tokenizer(
        answer_texts,
        padding=True,
        max_length=256,
        truncation=True,
    )
    ctx_ids = context_enc["input_ids"]
    ctx_mask = context_enc["attention_mask"]
    tgt_ids = target_enc["input_ids"]
    tgt_mask = target_enc["attention_mask"]
    n_choices = len(tgt_ids)

    if auto_class == 'Seq2SeqLM':
        # The encoder sees the same context for every choice; the decoder is
        # scored on each answer's tokens.
        return {
            "input_ids": [ctx_ids] * n_choices,
            "attention_mask": [ctx_mask] * n_choices,
            "labels": tgt_ids,
            "labels_attention_mask": tgt_mask,
        }

    if not ctx_ids:
        raise ValueError(
            "context must tokenize to at least one token for CausalLM scoring.")
    # Decoder-only: the input is context + answer without its final token, and
    # the labels are the same sequence shifted left by one. Only positions that
    # predict answer tokens are kept by labels_attention_mask.
    # NOTE(review): context (<=1024 tokens) plus answer (<=256) may exceed the
    # model's maximum position count for some models; confirm for long inputs.
    n_ctx = len(ctx_ids)
    return {
        "input_ids": [ctx_ids + ids[:-1] for ids in tgt_ids],
        "attention_mask": [ctx_mask + mask[:-1] for mask in tgt_mask],
        "labels": [ctx_ids[1:] + ids for ids in tgt_ids],
        "labels_attention_mask": [[0] * (n_ctx - 1) + mask for mask in tgt_mask],
    }


def multi_inference_rank_eval(model_name_or_path, auto_class, ex_answer_choices, context):
    """Rank answer choices by model log-likelihood and return the best index.

    Each answer in ``ex_answer_choices`` is scored as a continuation of
    ``context``. Its score is the sum of the log-probabilities of its tokens
    (padding excluded). Scores are not length-normalised. The model and
    tokenizer are loaded from ``model_name_or_path`` on every call.

    Args:
        model_name_or_path: Hugging Face model id or local path.
        auto_class: 'Seq2SeqLM' for encoder-decoder models; any other value
            loads the model with AutoModelForCausalLM.
        ex_answer_choices: candidate answer strings.
        context: prompt text the answers are conditioned on.

    Returns:
        int: index into ``ex_answer_choices`` of the highest-scoring answer.

    Raises:
        ValueError: if ``ex_answer_choices`` is empty, if no pad token can be
            chosen, or (CausalLM) if ``context`` tokenizes to nothing.
    """
    if not ex_answer_choices:
        raise ValueError("ex_answer_choices must contain at least one answer.")

    accelerator = Accelerator()
    set_seed(42)
    model, tokenizer = _load_model_and_tokenizer(model_name_or_path, auto_class)

    device = accelerator.device
    model.to(device)
    features = _tokenize_choices(tokenizer, context, ex_answer_choices, auto_class)
    batch = {key: torch.tensor(rows).to(device) for key, rows in features.items()}

    # Seq2Seq models need the labels to build decoder inputs; causal models
    # already receive the answer tokens inside input_ids.
    input_keys = ["input_ids", "attention_mask"]
    if auto_class == 'Seq2SeqLM':
        input_keys.append("labels")

    model.eval()
    with torch.no_grad():
        logits = model(**{key: batch[key] for key in input_keys}).logits
        log_probs = torch.log_softmax(logits, dim=-1)
        # Gather the label token's log-prob first, then mask. This is
        # identical to masking the full vocab tensor first, but avoids a
        # (choices x seq x vocab) multiply.
        token_log_probs = torch.gather(
            log_probs, -1, batch["labels"].unsqueeze(-1)).squeeze(-1)
        seq_log_prob = (token_log_probs * batch["labels_attention_mask"]).sum(dim=-1)

    predictions = seq_log_prob.view(1, -1).argmax(dim=-1)
    # gather is a collective that concatenates one prediction per process
    # along dim 0. Every process scores the same input, so take the first
    # entry rather than calling .item() on a tensor that may hold several.
    predictions = accelerator.gather(predictions)
    return predictions[0].item()
if __name__ == "__main__":
    # Demo: rank the answer choices with a T5 model and print the winner.
    # For a decoder-only model, call e.g.
    #   multi_inference_rank_eval('gpt2', 'CausalLM', choices, context)
    demo_choices = ['True', 'False', 'True', 'Ken']
    best = multi_inference_rank_eval('google/t5-small-lm-adapt', 'Seq2SeqLM',
                                     demo_choices, 'I am Ken. True or False')
    print(f"Predicted choice {best}: {demo_choices[best]!r}")