# Scraped page header (not part of the original script):
# file size 1,149 bytes; commits d8d4c8d, 642d911, d8d4c8d.
from transformers import BartTokenizer, BartForConditionalGeneration
from datamodules import IdiomifyDataModule
# Settings handed to IdiomifyDataModule in main().
# NOTE(review): the meaning of each key (e.g. which dataset "pie_v0" selects)
# is defined by IdiomifyDataModule, not here — confirm there.
CONFIG = dict(
    literal2idiomatic_ver="pie_v0",
    batch_size=20,
    num_workers=4,
    shuffle=True,
)
def main():
    """Run one BART forward pass on the first training batch and print the logits shape.

    Loads the ``facebook/bart-large`` tokenizer and model, builds an
    ``IdiomifyDataModule`` from ``CONFIG``, feeds the first batch from its
    training dataloader through the model and prints the shape of the output
    logits. Does nothing if the training dataloader yields no batches.
    """
    # Local import: only needed for no_grad below.
    import torch

    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
    bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
    datamodule = IdiomifyDataModule(CONFIG, tokenizer)
    datamodule.prepare_data()
    datamodule.setup()

    # Only the first batch is needed to inspect the output shape.
    batch = next(iter(datamodule.train_dataloader()), None)
    if batch is None:
        return
    srcs, tgts_r, _ = batch  # third element (targets) is unused here
    # NOTE(review): assumes srcs / tgts_r stack (ids, attention mask) along
    # dim 1, i.e. shape (N, 2, L) — confirm against IdiomifyDataModule.
    input_ids, attention_mask = srcs[:, 0], srcs[:, 1]
    decoder_input_ids, decoder_attention_mask = tgts_r[:, 0], tgts_r[:, 1]

    # Inference only: don't build an autograd graph we never backprop through.
    with torch.no_grad():
        outputs = bart(input_ids=input_ids,
                       attention_mask=attention_mask,
                       decoder_input_ids=decoder_input_ids,
                       decoder_attention_mask=decoder_attention_mask)
    logits = outputs[0]
    # Output recorded by the original author: torch.Size([20, 47, 50265]),
    # i.e. (N, L, |V|).
    print(logits.shape)
# Run the smoke test only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|