idiomify / explore /explore_bart_logits_shape.py
eubinecto's picture
[#1] Alpha implemented
d8d4c8d
raw
history blame
1.14 kB
from transformers import BartTokenizer, BartForConditionalGeneration
from data import IdiomifyDataModule
# Settings handed to IdiomifyDataModule in main(): which literal->idiomatic
# dataset version to load, plus batching options (presumably forwarded to the
# underlying DataLoader — TODO confirm against IdiomifyDataModule).
CONFIG = dict(
    literal2idiomatic_ver="pie_v0",
    batch_size=20,
    num_workers=4,
    shuffle=True,
)
def main():
    """Print the shape of BART's output logits for one training batch.

    Loads the pretrained ``facebook/bart-large`` tokenizer and model, builds an
    ``IdiomifyDataModule`` from ``CONFIG``, takes the first batch of the
    training dataloader, runs a single forward pass and prints
    ``logits.shape``. If the training dataloader yields no batches, nothing is
    printed.
    """
    # Local import: torch is only needed here to switch off autograd.
    import torch

    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
    bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
    datamodule = IdiomifyDataModule(CONFIG, tokenizer)
    datamodule.prepare_data()
    datamodule.setup()

    # Only the first batch is needed to inspect shapes.
    batch = next(iter(datamodule.train_dataloader()), None)
    if batch is None:
        return
    srcs, tgts_r, _tgts = batch
    # Assumes srcs / tgts_r stack (ids, attention mask) along dim 1
    # — TODO confirm against IdiomifyDataModule.
    input_ids, attention_mask = srcs[:, 0], srcs[:, 1]
    decoder_input_ids, decoder_attention_mask = tgts_r[:, 0], tgts_r[:, 1]

    # Pure inspection: skip building the autograd graph so the forward pass
    # through bart-large does not keep activations around for backprop.
    with torch.no_grad():
        outputs = bart(input_ids=input_ids,
                       attention_mask=attention_mask,
                       decoder_input_ids=decoder_input_ids,
                       decoder_attention_mask=decoder_attention_mask)
    logits = outputs[0]
    # Recorded output from a previous run: torch.Size([20, 47, 50265]),
    # read as (N, L, |V|).
    print(logits.shape)
if __name__ == "__main__":
    # Run the exploration only when executed as a script, not when imported.
    main()