import itertools

from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

import torch
from torch.nn import CrossEntropyLoss
from transformers import T5Tokenizer, T5ForConditionalGeneration, GPT2Tokenizer, PreTrainedTokenizer, PreTrainedModel

from generation import SelfDebiasingGPT2LMHeadModel


class ModelWrapper(ABC):
    """
    This class represents a wrapper for a pretrained language model that provides some high-level functions, including zero-shot
    classification using cloze questions and the generation of texts with self-debiasing.
    """

    def __init__(self, use_cuda: bool = True):
        """
        :param use_cuda: whether to use CUDA
        """
        self._device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
        self._tokenizer = None  # type: Optional[PreTrainedTokenizer]
        self._model = None  # type: Optional[PreTrainedModel]

    def query_model(self, input_text: str) -> torch.FloatTensor:
        """For a given input text, returns the probability distribution over possible next tokens."""
        return self.query_model_batch([input_text])[0]

    @abstractmethod
    def query_model_batch(self, input_texts: List[str]) -> torch.FloatTensor:
        """For a batch of input texts, returns the probability distribution over possible next tokens."""
        pass

    @abstractmethod
    def generate(self, input_text: str, **kwargs) -> str:
        """Generates a continuation for a given input text."""
        pass

    @abstractmethod
    def generate_self_debiasing(self, input_texts: List[str], debiasing_prefixes: List[str], decay_constant: float = 50,
                                epsilon: float = 0.01, debug: bool = False, **kwargs) -> List[str]:
        """
        Generates continuations for the given input texts with self-debiasing.
        :param input_texts: the input texts to generate continuations for
        :param debiasing_prefixes: the debiasing prefixes to be used
        :param decay_constant: the decay constant (lambda in the paper)
        :param epsilon: the minimum factor by which each probability is multiplied
        :param debug: whether to print additional debugging output
        :param kwargs: further arguments are passed on to the original generate function
        :return: the list of generated continuations
        """
        pass

    @abstractmethod
    def compute_loss(self, input_ids: torch.LongTensor, labels: torch.LongTensor) -> torch.Tensor:
        """Computes cross-entropy loss for the given input ids and corresponding labels."""
        pass

    @abstractmethod
    def compute_loss_self_debiasing(self, input_ids: torch.Tensor, trg_len: int, debiasing_prefixes: List[str], decay_constant: float = 50,
                                    epsilon: float = 0.01, debug: bool = False) -> torch.Tensor:
        """
        Computes cross-entropy loss for the given input ids with self-debiasing.
        :param input_ids: the input ids
        :param trg_len: only the last trg_len tokens are considered for computing the loss
        :param debiasing_prefixes: the debiasing prefixes to be used
        :param decay_constant: the decay constant (lambda in the paper)
        :param epsilon: the minimum factor by which each probability is multiplied
        :param debug: whether to print additional debugging output
        :return: the cross entropy loss
        """
        pass

    def get_token_probability_distribution(self, input_texts: List[str], output_choices: List[str]) -> List[List[Tuple[str, float]]]:
        """
        For a batch of input texts, returns the probability distribution over possible next tokens considering only the given list of
        output choices.
        :param input_texts: the input texts
        :param output_choices: the allowed output choices (must correspond to single tokens in the model's vocabulary)
        :return: a list of lists, where output[i][j] is an (output, probability) tuple for the ith input and jth output choice
        """
        output_choice_ids = []
        kwargs = {'add_prefix_space': True} if isinstance(self, GPT2Wrapper) else {}
        for word in output_choices:
            tokens = self._tokenizer.tokenize(word, **kwargs)
            assert len(tokens) == 1, f"Word {word} consists of multiple tokens: {tokens}"
            assert tokens[0] not in self._tokenizer.all_special_tokens, f"Word {word} corresponds to a special token: {tokens[0]}"
            token_id = self._tokenizer.convert_tokens_to_ids(tokens)[0]
            output_choice_ids.append(token_id)

        logits = self.query_model_batch(input_texts)
        result = []

        for idx, _ in enumerate(input_texts):
            output_probabilities = logits[idx][output_choice_ids].softmax(dim=0)
            choices_with_probabilities = list(zip(output_choices, (prob.item() for prob in output_probabilities)))
            result.append(choices_with_probabilities)

        return result
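

# Usage sketch for cloze-style zero-shot classification via get_token_probability_distribution.
# The prompt template and the output choices "Yes"/"No" below are illustrative assumptions; any
# choices that map to single tokens in the wrapped model's vocabulary work the same way.
def _example_zero_shot_classification(wrapper: ModelWrapper) -> None:
    prompts = ['"I really enjoyed this movie." Question: Does the text above express a positive sentiment? Answer:']
    # Each output choice must be a single token in the model's vocabulary (asserted by the wrapper).
    distributions = wrapper.get_token_probability_distribution(prompts, output_choices=['Yes', 'No'])
    for choice, probability in distributions[0]:
        print(f'{choice}: {probability:.4f}')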


class T5Wrapper(ModelWrapper):
    """A wrapper for the T5 model"""

    def __init__(self, model_name: str = "google/t5-v1_1-xl", use_cuda: bool = True):
        """
        :param model_name: the name of the pretrained T5 model (default: "google/t5-v1_1-xl")
        :param use_cuda: whether to use CUDA
        """
        super().__init__(use_cuda=use_cuda)
        self._tokenizer = T5Tokenizer.from_pretrained(model_name)
        self._model = T5ForConditionalGeneration.from_pretrained(model_name)
        if use_cuda:
            self._model.parallelize()

    def query_model_batch(self, input_texts: List[str]):
        assert all('<extra_id_0>' in input_text for input_text in input_texts)
        output_texts = ['<extra_id_0>'] * len(input_texts)
        inputs = self._tokenizer.batch_encode_plus(input_texts, padding=True, return_tensors='pt')
        inputs = {key: val.to(self._device) for key, val in inputs.items()}
        output_ids = self._tokenizer.batch_encode_plus(output_texts, return_tensors='pt')['input_ids'].to(self._device)
        return self._model(labels=output_ids, **inputs)['logits'][:, 1, :]

    def generate(self, input_text: str, **kwargs):
        assert '<extra_id_0>' in input_text
        input_ids = self._tokenizer.encode(input_text, return_tensors='pt').to(self._device)
        output_ids = self._model.generate(input_ids, **kwargs)[0]
        return self._tokenizer.decode(output_ids)

    def generate_self_debiasing(self, input_texts: List[str], debiasing_prefixes: List[str], decay_constant: float = 50,
                                epsilon: float = 0.01, debug: bool = False, **kwargs) -> List[str]:
        raise NotImplementedError()

    def compute_loss(self, input_ids: torch.LongTensor, labels: torch.LongTensor) -> torch.Tensor:
        raise NotImplementedError()

    def compute_loss_self_debiasing(self, input_ids: torch.Tensor, trg_len: int, debiasing_prefixes: List[str], decay_constant: float = 50,
                                    epsilon: float = 0.01, debug: bool = False) -> torch.Tensor:
        raise NotImplementedError()
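

# Usage sketch for a cloze-style query with the T5 wrapper: query_model expects the sentinel
# '<extra_id_0>' to mark the blank. The prompt below and the use of the private _tokenizer
# attribute are illustrative assumptions only.
def _example_t5_cloze_query(wrapper: T5Wrapper) -> None:
    prompt = 'The capital of France is <extra_id_0>.'
    # Logits over the vocabulary for the token that fills the sentinel position.
    blank_logits = wrapper.query_model(prompt)
    top_token_id = int(blank_logits.argmax())
    print(wrapper._tokenizer.convert_ids_to_tokens([top_token_id]))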


class GPT2Wrapper(ModelWrapper):
    """A wrapper for the GPT-2 model"""

    def __init__(self, model_name: str = "gpt2-xl", use_cuda: bool = True):
        """
        :param model_name: the name of the pretrained GPT2 model (default: "gpt2-xl")
        :param use_cuda: whether to use CUDA
        """
        super().__init__(use_cuda=use_cuda)
        self._tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self._model = SelfDebiasingGPT2LMHeadModel.from_pretrained(model_name)  # type: SelfDebiasingGPT2LMHeadModel
        if use_cuda:
            self._model.parallelize()
        self._tokenizer.pad_token = self._tokenizer.eos_token
        self._model.config.pad_token_id = self._tokenizer.eos_token_id

    def query_model_batch(self, input_texts: List[str]):
        inputs = self._tokenizer.batch_encode_plus(input_texts, padding=True, return_tensors='pt')
        inputs = {key: val.to(self._device) for key, val in inputs.items()}
        # index of the last non-padding token for each example
        output_indices = inputs['attention_mask'].sum(dim=1) - 1
        output = self._model(**inputs)['logits']
        return torch.stack([output[example_idx, last_word_idx, :] for example_idx, last_word_idx in enumerate(output_indices)])

    def generate(self, input_text: str, **kwargs):
        input_ids = self._tokenizer.encode(input_text, return_tensors='pt').to(self._device)
        output_ids = self._model.generate(input_ids, **kwargs)[0]
        return self._tokenizer.decode(output_ids)

    def generate_self_debiasing(self, input_texts: List[str], debiasing_prefixes: List[str], decay_constant: float = 50,
                                epsilon: float = 0.01, debug: bool = False, min_length: int = None, max_length: int = None,
                                **kwargs) -> List[str]:
        self._model.init_logits_processor(num_debiasing_prefixes=len(debiasing_prefixes), decay_constant=decay_constant, epsilon=epsilon,
                                          debug=debug, tokenizer=self._tokenizer)

        # Build one batch that contains each input text once without a prefix, followed by one copy per debiasing prefix.
        inputs = input_texts.copy()
        for debiasing_prefix in debiasing_prefixes:
            for input_text in input_texts:
                inputs += [debiasing_prefix + input_text]

        inputs = self._tokenizer.batch_encode_plus(inputs, padding=True, return_tensors='pt')

        # Convert right-padding to left-padding so that all sequences end at the same position before generation.
        inputs['attention_mask'] = torch.flip(inputs['attention_mask'], dims=[1])
        shifts = inputs['attention_mask'].shape[-1] - inputs['attention_mask'].sum(dim=-1)
        for batch_idx in range(inputs['input_ids'].shape[0]):
            inputs['input_ids'][batch_idx] = inputs['input_ids'][batch_idx].roll(shifts[batch_idx].item())

        inputs = {k: v.to(self._device) for k, v in inputs.items()}

        # min_length and max_length are interpreted relative to the continuation, so add the (padded) input length.
        input_length = inputs['input_ids'].shape[1]
        if min_length is not None:
            min_length = min_length + input_length
        if max_length is not None:
            max_length = max_length + input_length

        output_ids = self._model.generate(**inputs, min_length=min_length, max_length=max_length, **kwargs)

        # Keep only the continuations of the unprefixed inputs and strip the (padded) prompt tokens.
        batch_size = output_ids.shape[0] // (1 + len(debiasing_prefixes))
        output_ids = output_ids[:batch_size, inputs['input_ids'].shape[1]:]
        return self._tokenizer.batch_decode(output_ids)

    def compute_loss(self, input_ids: torch.LongTensor, labels: torch.LongTensor) -> torch.Tensor:
        outputs = self._model(input_ids, labels=labels)
        lm_logits = outputs[1]

        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return loss

    def compute_loss_self_debiasing(self, input_ids: torch.Tensor, trg_len: int, debiasing_prefixes: List[str], decay_constant: float = 50,
                                    epsilon: float = 0.01, debug: bool = False) -> torch.Tensor:
        self._model.init_logits_processor(num_debiasing_prefixes=len(debiasing_prefixes), decay_constant=decay_constant, epsilon=epsilon,
                                          debug=debug, tokenizer=self._tokenizer)

        # Encode the empty prefix (for the regular input) and all debiasing prefixes, then convert them to left-padding.
        input_prefixes = [''] + debiasing_prefixes
        input_prefixes = self._tokenizer.batch_encode_plus(input_prefixes, padding=True, return_tensors='pt')
        input_prefixes['attention_mask'] = torch.flip(input_prefixes['attention_mask'], dims=[1])

        shifts = input_prefixes['attention_mask'].shape[-1] - input_prefixes['attention_mask'].sum(dim=-1)
        for batch_idx in range(input_prefixes['input_ids'].shape[0]):
            input_prefixes['input_ids'][batch_idx] = input_prefixes['input_ids'][batch_idx].roll(shifts[batch_idx].item())

        input_prefixes = {k: v.to(self._device) for k, v in input_prefixes.items()}

        input_ids_repeated = input_ids.repeat(len(debiasing_prefixes) + 1, 1)
        attention_mask = torch.ones_like(input_ids_repeated)
        attention_mask = torch.cat([input_prefixes['attention_mask'], attention_mask], dim=-1)
        input_ids_repeated = torch.cat([input_prefixes['input_ids'], input_ids_repeated], dim=-1)

        # Adjust trg_len for the left-padding of the empty prefix, then mask everything except the last
        # trg_len tokens with -100 so that it is ignored by the loss.
        target_ids = input_ids_repeated.clone()
        trg_len += shifts[0]
        target_ids[:, :-trg_len] = -100

        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)

        outputs = self._model(input_ids=input_ids_repeated, attention_mask=attention_mask, position_ids=position_ids, labels=target_ids)
        lm_logits = outputs[1]

        # Apply the self-debiasing logits processor to every position before computing the loss.
        for idx in range(lm_logits.shape[1]):
            lm_logits[:, idx, :] = self._model.logits_processor(input_ids=None, scores=lm_logits[:, idx, :])

        # Keep only the rows for the input without a debiasing prefix and drop its padding positions.
        batch_size = lm_logits.shape[0] // (1 + len(debiasing_prefixes))
        lm_logits = lm_logits[:batch_size, shifts[0]:, :]
        target_ids = target_ids[:batch_size, shifts[0]:]

        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = target_ids[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return loss
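

if __name__ == '__main__':
    # Minimal self-debiasing generation sketch. The model name, prompt and debiasing prefix below are
    # illustrative assumptions (the actual prefixes describe the undesired attribute in natural language),
    # and this assumes that SelfDebiasingGPT2LMHeadModel can load the smaller "gpt2" checkpoint.
    wrapper = GPT2Wrapper(model_name='gpt2', use_cuda=False)
    prompts = ['Two guys walk into a bar and']
    prefixes = ['The following text contains rude, disrespectful, or unreasonable language:\n']
    continuations = wrapper.generate_self_debiasing(
        prompts, debiasing_prefixes=prefixes, decay_constant=50, epsilon=0.01,
        min_length=20, max_length=20, do_sample=True, top_k=5)
    print(continuations[0])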