Spaces:
Build error
Build error
from unittest import TestCase | |
import torch | |
from transformers import AutoConfig, AutoModelWithLMHead, AutoTokenizer | |
import autoprompt.create_trigger as ct | |
def _load(model_name):
    """Load the pretrained config, LM-head model, and tokenizer for a checkpoint.

    Args:
        model_name: Checkpoint identifier handed to every ``from_pretrained``
            call (the tests in this file pass ``'bert-base-cased'`` and
            ``'roberta-base'``).

    Returns:
        A ``(config, model, tokenizer)`` tuple.
    """
    # All three artifacts are resolved from `model_name` so that they belong
    # to the checkpoint the caller asked for, not to a fixed default.
    config = AutoConfig.from_pretrained(model_name)
    model = AutoModelWithLMHead.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return config, model, tokenizer
class TestGetEmbeddings(TestCase):
    """Checks ``ct.get_embeddings`` against the pretrained checkpoints used here."""

    def _check_vocab_rows(self, checkpoint):
        """Assert the embedding weight's first dimension equals ``config.vocab_size``."""
        # The tokenizer returned by `_load` is not needed for this check.
        config, model, _ = _load(checkpoint)
        embedding_layer = ct.get_embeddings(model, config)
        row_count = embedding_layer.weight.shape[0]
        self.assertEqual(row_count, config.vocab_size)

    def test_bert(self):
        self._check_vocab_rows('bert-base-cased')

    def test_roberta(self):
        self._check_vocab_rows('roberta-base')
class TestGradientStorage(TestCase):
    """Tests for ``ct.GradientStorage``."""

    def test_gradient_storage(self):
        """The value from ``get()`` equals the gradient retained on the embedding output."""
        num_embeddings = 3
        embedding_dim = 4
        embeddings = torch.nn.Embedding(num_embeddings, embedding_dim)
        # The storage is attached to the module before the forward/backward
        # pass below is run.
        embedding_storage = ct.GradientStorage(embeddings)

        inputs = torch.tensor([0, 1, 2, 1])
        outputs = embeddings(inputs)
        # `outputs` is a non-leaf tensor; its .grad is only populated when
        # explicitly retained.
        outputs.retain_grad()
        loss = outputs.sum()
        loss.backward()

        # Use the TestCase assertion (not a bare `assert`) so the check is
        # not stripped when Python runs with -O.
        self.assertTrue(
            torch.equal(outputs.grad, embedding_storage.get()),
            'stored gradient differs from the gradient of the embedding output',
        )
def test_replace_trigger_tokens():
    """Check the ``input_ids`` produced by ``ct.replace_trigger_tokens`` on a 2x4 batch."""
    # Expected result: in each row, the masked positions hold the trigger ids
    # 5 then 6 (left to right); unmasked positions keep their original ids.
    expected = torch.tensor([[5, 6, 3, 4], [1, 5, 1, 6]])

    token_ids = torch.tensor([[1, 2, 3, 4], [1, 1, 1, 0]])
    model_inputs = {'input_ids': token_ids}
    trigger_mask = torch.tensor([
        [True, True, False, False],
        [False, True, False, True],
    ])
    trigger_ids = torch.tensor([[5, 6]])

    replaced = ct.replace_trigger_tokens(model_inputs, trigger_ids, trigger_mask)
    assert torch.equal(expected, replaced['input_ids'])