"""Utility functions for dataset loading, collation, and trigger templatization."""
import copy
import csv
import json
import logging
import random
from collections import defaultdict

import torch
from torch.nn.utils.rnn import pad_sequence

MAX_CONTEXT_LEN = 50

logger = logging.getLogger(__name__)
def pad_squeeze_sequence(sequence, *args, **kwargs):
    """Drop the fake batch dimension added by the tokenizer, then pad.

    Each tensor in `sequence` is expected to carry a leading batch dim of
    size 1 (as produced by `encode_plus(..., return_tensors='pt')`); it is
    squeezed away before all tensors are padded into a single batch.
    Extra positional/keyword arguments are forwarded to `pad_sequence`.
    """
    squeezed = [tensor.squeeze(0) for tensor in sequence]
    return pad_sequence(squeezed, *args, **kwargs)
class OutputStorage:
    """
    Caches the forward-pass output of the given PyTorch module, which
    otherwise might not be retained after the forward call returns.
    """
    def __init__(self, module):
        # Most recent output of `module`; None until the first forward pass.
        self._stored_output = None
        module.register_forward_hook(self.hook)

    def hook(self, module, input, output):
        # Invoked by PyTorch after every forward pass of `module`.
        self._stored_output = output

    def get(self):
        # Return the output captured by the most recent forward pass.
        return self._stored_output
class ExponentialMovingAverage:
    """Accumulates a metric and reports its running mean since the last reset.

    NOTE(review): despite the class name, this computes a plain arithmetic
    mean — the `weight` argument is stored but never read. Confirm whether a
    true exponential moving average was intended before relying on the name.
    """

    def __init__(self, weight=0.3):
        self._weight = weight
        self.reset()

    def update(self, x):
        # Fold one observation into the running sum / count.
        self._i += 1
        self._x += x

    def reset(self):
        # Clear the running sum and the observation count.
        self._x = 0
        self._i = 0

    def get_metric(self):
        # The tiny epsilon keeps the division safe when no observations
        # have been recorded yet (returns 0.0 in that case).
        denominator = self._i + 1e-13
        return self._x / denominator
class Collator:
    """
    Collates transformer outputs.

    Pads every field of the per-example model inputs (and the labels) to a
    common length so individual examples can be stacked into one batch.
    """

    def __init__(self, pad_token_id=0):
        self._pad_token_id = pad_token_id

    def __call__(self, features):
        # Split the (model_inputs, label) pairs into two parallel tuples.
        model_inputs, labels = list(zip(*features))
        # Assume every input dict shares the keys of the first example.
        padded_inputs = {}
        for key in list(model_inputs[0].keys()):
            # Only token ids are padded with the tokenizer's pad id; masks
            # and other auxiliary fields are padded with zeros.
            padding_value = self._pad_token_id if key == 'input_ids' else 0
            # Squeeze away the fake batch dim added by the tokenizer before
            # padding this field across examples.
            padded_inputs[key] = pad_sequence(
                [example[key].squeeze(0) for example in model_inputs],
                batch_first=True,
                padding_value=padding_value,
            )
        labels = pad_sequence(
            [label.squeeze(0) for label in labels],
            batch_first=True,
            padding_value=0,
        )
        return padded_inputs, labels
def encode_label(tokenizer, label, tokenize=False):
    """
    Helper function for encoding labels. Deals with the subtleties of
    handling multiple tokens.

    Parameters
    ==========
    tokenizer : transformers.PretrainedTokenizer
        Used to map token strings to vocabulary ids.
    label : str | list | int
        A single token string, a list of token strings, or a raw label id.
    tokenize : bool
        If True and `label` is a string, tokenize it first (single-token
        labels only).

    Returns
    =======
    torch.Tensor of shape (1, num_label_tokens) holding the label id(s).

    Raises
    ======
    ValueError
        If a tokenized label maps to multiple tokens or to the unk token.
    TypeError
        If `label` is not a str, list, or int.
    """
    if isinstance(label, str):
        if tokenize:
            # Ensure label is properly tokenized, and only retain first token
            # if it gets split into multiple tokens. TODO: Make sure this is
            # desired behavior.
            tokens = tokenizer.tokenize(label)
            if len(tokens) > 1:
                raise ValueError(f'Label "{label}" gets mapped to multiple tokens.')
            if tokens[0] == tokenizer.unk_token:
                raise ValueError(f'Label "{label}" gets mapped to unk.')
            label = tokens[0]
        encoded = torch.tensor(tokenizer.convert_tokens_to_ids([label])).unsqueeze(0)
    elif isinstance(label, list):
        encoded = torch.tensor(tokenizer.convert_tokens_to_ids(label)).unsqueeze(0)
    elif isinstance(label, int):
        encoded = torch.tensor([[label]])
    else:
        # BUG FIX: previously an unsupported type fell through and raised an
        # opaque NameError on the unbound `encoded`; fail fast instead.
        raise TypeError(f'Unsupported label type: {type(label)}')
    return encoded
class TriggerTemplatizer:
    """
    An object to facilitate creating transformers-friendly triggers inputs from a template.

    Parameters
    ==========
    template : str
        The template string, comprised of the following tokens:
            [T] to mark a trigger placeholder.
            [P] to mark a prediction placeholder.
            {fields} arbitrary fields instantiated from the dataset instances.
        For example a NLI template might look like:
            "[T] [T] [T] {premise} [P] {hypothesis}"
    tokenizer : PretrainedTokenizer
        A HuggingFace tokenizer. Must have special trigger and predict tokens.
    add_special_tokens : bool
        Whether or not to add special tokens when encoding. Default: False.
    """
    def __init__(self,
                 template,
                 config,
                 tokenizer,
                 label_field='label',
                 label_map=None,
                 tokenize_labels=False,
                 add_special_tokens=False,
                 use_ctx=False):
        if not hasattr(tokenizer, 'predict_token') or \
                not hasattr(tokenizer, 'trigger_token'):
            # NOTE(review): the two concatenated literals below lack a
            # separating space, so the message renders as "...vocab.Use...".
            raise ValueError(
                'Tokenizer missing special trigger and predict tokens in vocab.'
                'Use `utils.add_special_tokens` to add them.'
            )
        self._template = template
        self._config = config
        self._tokenizer = tokenizer
        self._label_field = label_field
        self._label_map = label_map
        self._tokenize_labels = tokenize_labels
        self._add_special_tokens = add_special_tokens
        self._use_ctx = use_ctx

    def num_trigger_tokens(self):
        # Count the [T] placeholders in the template.
        return sum(token == '[T]' for token in self._template.split())

    def __call__(self, format_kwargs):
        """Templatize one instance dict into (model_inputs, label_id)."""
        # Format the template string; the label field is popped out so it is
        # not substituted into the template text.
        format_kwargs = format_kwargs.copy()
        label = format_kwargs.pop(self._label_field)
        text = self._template.format(**format_kwargs)
        if label is None:
            raise Exception(f'Bad data: {text}')

        # Have the tokenizer encode the text and process the output to:
        # - Create a trigger and predict mask
        # - Replace the predict token with a mask token
        model_inputs = self._tokenizer.encode_plus(
            text,
            add_special_tokens=self._add_special_tokens,
            return_tensors='pt'
        )
        input_ids = model_inputs['input_ids']
        trigger_mask = input_ids.eq(self._tokenizer.trigger_token_id)
        predict_mask = input_ids.eq(self._tokenizer.predict_token_id)
        # In-place replacement: the model predicts at the mask position.
        input_ids[predict_mask] = self._tokenizer.mask_token_id

        model_inputs['trigger_mask'] = trigger_mask
        model_inputs['predict_mask'] = predict_mask

        # For relation extraction with BERT, update token_type_ids to reflect the two different sequences
        if self._use_ctx and self._config.model_type == 'bert':
            # Locate the [SEP] tokens; everything between the first and the
            # second one (inclusive) is marked as sequence B.
            sep_token_indices = (input_ids.squeeze(0) == self._tokenizer.convert_tokens_to_ids(self._tokenizer.sep_token)).nonzero().flatten()
            sequence_b_indices = torch.arange(sep_token_indices[0], sep_token_indices[1] + 1).long().unsqueeze(0)
            model_inputs['token_type_ids'].scatter_(1, sequence_b_indices, 1)

        # Encode the label(s), mapping it through label_map when provided.
        if self._label_map is not None:
            label = self._label_map[label]
        label_id = encode_label(
            tokenizer=self._tokenizer,
            label=label,
            tokenize=self._tokenize_labels
        )

        return model_inputs, label_id
def add_task_specific_tokens(tokenizer):
    """Register the [T] / [P] / [Y] special tokens on `tokenizer`.

    Adds the tokens to the vocabulary and caches both the string and the id
    of each one as attributes on the tokenizer for convenient lookup.
    """
    tokenizer.add_special_tokens({
        'additional_special_tokens': ['[T]', '[P]', '[Y]']
    })
    tokenizer.trigger_token = '[T]'
    tokenizer.trigger_token_id = tokenizer.convert_tokens_to_ids('[T]')
    tokenizer.predict_token = '[P]'
    tokenizer.predict_token_id = tokenizer.convert_tokens_to_ids('[P]')
    # NOTE: BERT and RoBERTa tokenizers work properly if [X] is not a special token...
    # tokenizer.lama_x = '[X]'
    # tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[X]')
    tokenizer.lama_y = '[Y]'
    # BUG FIX: the id of '[Y]' was previously stored only under the
    # misleading name `lama_x_id`. Store it under the matching `lama_y_id`,
    # keeping the old attribute for backward compatibility with callers.
    tokenizer.lama_y_id = tokenizer.convert_tokens_to_ids('[Y]')
    tokenizer.lama_x_id = tokenizer.lama_y_id
def load_tsv(fname):
    """Yield each row of a tab-separated file as a dict keyed by header."""
    with open(fname, 'r') as handle:
        yield from csv.DictReader(handle, delimiter='\t')
def load_jsonl(fname):
    """Yield one parsed JSON object per line of `fname`."""
    with open(fname, 'r') as handle:
        for raw_line in handle:
            yield json.loads(raw_line)
# Maps a dataset file suffix to the generator used to parse that file type.
LOADERS = {
    '.tsv': load_tsv,
    '.jsonl': load_jsonl
}
def load_trigger_dataset(fname, templatizer, use_ctx, limit=None):
    """Load a dataset and templatize each instance.

    Parameters
    ==========
    fname : pathlib.Path
        Dataset file; its suffix selects the loader (.tsv or .jsonl).
    templatizer : callable
        Maps a raw instance dict to a (model_inputs, label_id) pair;
        instances for which it raises ValueError are skipped.
    use_ctx : bool
        If True (relation extraction), attach a randomly chosen, truncated
        context sentence to each instance before templatizing.
    limit : int, optional
        If truthy, return a random sample of `limit` instances.

    Returns
    =======
    list of (model_inputs, label_id) pairs.
    """
    loader = LOADERS[fname.suffix]
    instances = []
    for x in loader(fname):
        try:
            if use_ctx:
                # For relation extraction, skip facts that don't have a
                # context sentence.
                if 'evidences' not in x:
                    logger.warning('Skipping RE sample because it lacks context sentences: {}'.format(x))
                    continue
                evidences = x['evidences']
                # Randomly pick a context sentence.
                obj_surface, masked_sent = random.choice(
                    [(evidence['obj_surface'], evidence['masked_sentence'])
                     for evidence in evidences]
                )
                words = masked_sent.split()
                if len(words) > MAX_CONTEXT_LEN:
                    # If the masked sentence is too long, use only the first
                    # MAX_CONTEXT_LEN tokens. For training we want to keep as
                    # many samples as we can.
                    masked_sent = ' '.join(words[:MAX_CONTEXT_LEN])
                # If the truncated context sentence still has the mask, replace
                # it with the object surface form. We explicitly use [MASK]
                # because all TREx fact context sentences use it.
                x['context'] = masked_sent.replace('[MASK]', obj_surface)
            # NOTE: the templatizer call was previously duplicated verbatim in
            # both branches of the if/else; it is invoked once here instead.
            model_inputs, label_id = templatizer(x)
        except ValueError as e:
            logger.warning('Encountered error "%s" when processing "%s". Skipping.', e, x)
            continue
        else:
            instances.append((model_inputs, label_id))
    if limit:
        return random.sample(instances, limit)
    else:
        return instances
def load_augmented_trigger_dataset(fname, templatizer, limit=None):
    """Load an RE dataset and build synthetic facts with swapped objects.

    Every fact's object label is replaced by a different object drawn from
    the same dataset, and one of the new object's surface forms is spliced
    into the fact's masked context sentence before templatizing.
    """
    loader = LOADERS[fname.suffix]
    instances = []

    # For augmented relation extraction, we need to replace obj_label with another obj_label, and replace obj_surface with a surface form of the new obj_label
    unique_objs_dict = defaultdict(list)
    # Also for augmented relation extraction, we need to accumulate all facts and process them afterwards
    facts = []
    for x in loader(fname):
        try:
            # Reading these fields up front also validates their presence;
            # a missing key raises KeyError (not caught below).
            sub_label = x['sub_label']
            obj_label = x['obj_label']

            # For relation extraction, skip facts that don't have context sentence
            if 'evidences' not in x:
                logger.warning('Skipping RE sample because it lacks context sentences: {}'.format(x))
                continue
            evidences = x['evidences']

            # Gather all UNIQUE objects and their surface forms if its augmented relation extraction
            for evidence in evidences:
                obj_surface = evidence['obj_surface']
                masked_sent = evidence['masked_sentence']
                unique_objs_dict[obj_label].append(obj_surface)

            # Randomly pick a context sentence
            obj_surface, masked_sent = random.choice([(evidence['obj_surface'], evidence['masked_sentence']) for evidence in evidences])
            words = masked_sent.split()
            if len(words) > MAX_CONTEXT_LEN:
                # If the masked sentence is too long, use the first X tokens. For training we want to keep as many samples as we can.
                masked_sent = ' '.join(words[:MAX_CONTEXT_LEN])
            x['context'] = masked_sent
            facts.append(x)
        except ValueError as e:
            logger.warning('Encountered error "%s" when processing "%s". Skipping.', e, x)

    # Go through all facts and replace each object with a new one. Also insert the new object (surface form) into the masked sentence
    synth_facts = []
    for fact in facts:
        sub_label = fact['sub_label']
        obj_label = fact['obj_label']
        masked_sent = fact['context']
        # NOTE(review): picking a replacement object assumes the dataset holds
        # at least two distinct obj_labels; otherwise this random.choice gets
        # an empty list and raises IndexError — confirm upstream guarantees.
        synth_obj_label = random.choice([x for x in unique_objs_dict.keys() if x != obj_label])
        synth_obj_surface = random.choice(unique_objs_dict[synth_obj_label])
        synth_ctx = masked_sent.replace('[MASK]', synth_obj_surface)
        # Reassign the labels and context sentence on a deep copy so the
        # original fact dict is left untouched.
        synth_fact = copy.deepcopy(fact)
        synth_fact['sub_label'] = sub_label
        synth_fact['obj_label'] = synth_obj_label
        synth_fact['context'] = synth_ctx
        synth_facts.append(synth_fact)

    # Go through facts, templatize each one, then append them to instances
    for fact in synth_facts:
        model_inputs, label_id = templatizer(fact)
        instances.append((model_inputs, label_id))

    if limit:
        return random.sample(instances, limit)
    else:
        return instances
def load_classification_dataset(
    fname,
    tokenizer,
    input_field_a,
    input_field_b=None,
    label_field='label',
    label_map=None,
    limit=None
):
    """
    Loads a dataset for classification.

    Parameters
    ==========
    fname : pathlib.Path
        Dataset file; its suffix selects the loader (.tsv or .jsonl).
    tokenizer : transformers.PretrainedTokenizer
        Maps text to id tensors.
    input_field_a : str
        Field holding the (first) input sentence.
    input_field_b : str, optional
        Field holding the second sentence for sentence-pair tasks.
    label_field : str
        Field holding the label. Default: 'label'.
    label_map : dict, optional
        Existing label-to-id mapping; unseen labels are added to it in place.
    limit : int, optional
        If truthy, return a random sample of `limit` instances.

    Returns
    =======
    (instances, label_map) where instances is a list of
    (model_inputs, label_id) pairs.
    """
    instances = []
    label_map = label_map or {}
    loader = LOADERS[fname.suffix]
    for instance in loader(fname):
        logger.debug(instance)
        model_inputs = tokenizer.encode_plus(
            instance[input_field_a],
            instance[input_field_b] if input_field_b else None,
            add_special_tokens=True,
            # add_prefix_space=True,
            return_tensors='pt'
        )
        logger.debug(model_inputs)
        label = instance[label_field]
        if label not in label_map:
            # Ids are assigned to labels in order of first appearance.
            label_map[label] = len(label_map)
        # Wrap as a (1, 1) tensor — the shape the collator expects.
        label_id = torch.tensor([[label_map[label]]])
        logger.debug(f'Label id: {label_id}')
        instances.append((model_inputs, label_id))
    if limit:
        instances = random.sample(instances, limit)
    return instances, label_map