File size: 1,149 Bytes
d8d4c8d
 
642d911
d8d4c8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from transformers import BartTokenizer, BartForConditionalGeneration

from datamodules import IdiomifyDataModule


# Settings handed straight to IdiomifyDataModule (see main()).
CONFIG = {
    "literal2idiomatic_ver": "pie_v0",  # dataset version tag — presumably selects the literal->idiomatic corpus; confirm in IdiomifyDataModule
    "batch_size": 20,
    "num_workers": 4,  # NOTE(review): likely DataLoader worker count — confirm against datamodule
    "shuffle": True
}


def main():
    """Push a single training batch through pretrained BART and print the logit shape."""
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
    bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
    datamodule = IdiomifyDataModule(CONFIG, tokenizer)
    datamodule.prepare_data()
    datamodule.setup()
    # Only the first batch is needed; the loop body breaks immediately after it.
    for srcs, tgts_r, tgts in datamodule.train_dataloader():  # noqa
        # Each tensor packs (ids, mask) along dim 1: index 0 = token ids, index 1 = attention mask.
        model_inputs = {
            "input_ids": srcs[:, 0],
            "attention_mask": srcs[:, 1],
            "decoder_input_ids": tgts_r[:, 0],
            "decoder_attention_mask": tgts_r[:, 1],
        }
        outputs = bart(**model_inputs)
        logits = outputs[0]
        # e.g. torch.Size([20, 47, 50265]) — (N, L, |V|)
        print(logits.shape)
        break


# Run the demo only when executed as a script, not on import.
if __name__ == '__main__':
    main()