|
from transformers import BartTokenizer, BartForConditionalGeneration |
|
|
|
from datamodules import IdiomifyDataModule |
|
|
|
|
|
# Settings handed to IdiomifyDataModule by main().
CONFIG = dict(
    # presumably selects the literal->idiomatic dataset version — TODO confirm
    literal2idiomatic_ver="pie_v0",
    # number of examples per batch
    batch_size=20,
    # presumably the DataLoader worker-process count — verify in datamodules
    num_workers=4,
    # presumably toggles shuffling of the training data — verify in datamodules
    shuffle=True,
)
|
|
|
|
|
def main():
    """Run BART on the first training batch and print the logits' shape.

    Loads the pretrained ``facebook/bart-large`` tokenizer and model, builds an
    ``IdiomifyDataModule`` from ``CONFIG``, performs a single forward pass on
    the first batch of ``train_dataloader()`` and prints the shape of the
    returned logits. Only the first batch is inspected.
    """
    # Local import: needed only for no_grad(); the BART model already
    # requires torch to be installed.
    import torch

    model_name = 'facebook/bart-large'
    tokenizer = BartTokenizer.from_pretrained(model_name)
    bart = BartForConditionalGeneration.from_pretrained(model_name)
    # Inspection pass only: make sure dropout etc. are in inference mode.
    bart.eval()

    datamodule = IdiomifyDataModule(CONFIG, tokenizer)
    datamodule.prepare_data()
    datamodule.setup()

    for batch in datamodule.train_dataloader():
        # The batch unpacks into three items; presumably (sources,
        # right-shifted targets, targets) — TODO confirm against datamodules.
        # The third item is not needed for this forward pass.
        srcs, tgts_r, _ = batch
        # Slot 0 along dim 1 is used as token ids and slot 1 as the attention
        # mask — assumes each tensor is (N, 2, L); TODO confirm.
        input_ids, attention_mask = srcs[:, 0], srcs[:, 1]
        decoder_input_ids, decoder_attention_mask = tgts_r[:, 0], tgts_r[:, 1]

        # No gradients are needed just to look at the output, so avoid
        # building the autograd graph for the whole model.
        with torch.no_grad():
            outputs = bart(input_ids=input_ids,
                           attention_mask=attention_mask,
                           decoder_input_ids=decoder_input_ids,
                           decoder_attention_mask=decoder_attention_mask)

        # First output element holds the LM logits.
        logits = outputs[0]
        # Observed with batch_size=20: torch.Size([20, 47, 50265]),
        # i.e. (N, L, |V|).
        print(logits.shape)
        break  # inspect only the first batch
|
|
|
|
|
# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|