File size: 929 Bytes
44db343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from models.transformer import TransformerWithTR
from models.collator import *
from transformers import AutoTokenizer
import transformers
from models.tokenizer import TokenAligner
from dataset.vocab import Vocab

class ModelWrapper:
    """Bundle a model with the tokenizer, aligner and collator built for it.

    Only the ``"tfmwtr"`` model name is supported; any other name is
    rejected at construction time.

    Attributes set by ``__init__``:
        model_name: The model name passed by the caller.
        tokenizer: ``AutoTokenizer`` loaded from ``PRETRAINED_CHECKPOINT``.
        tokenAligner: ``TokenAligner`` built from the tokenizer and ``vocab``.
        bart: ``MBartForConditionalGeneration`` loaded from
            ``PRETRAINED_CHECKPOINT``.
        model: ``TransformerWithTR`` built from ``bart`` and the tokenizer's
            pad token id.
        collator: ``DataCollatorForCharacterTransformer`` built from
            ``tokenAligner``.
    """

    # Single source of truth for the checkpoint, so the tokenizer and the
    # MBart weights are always loaded from the same place.
    PRETRAINED_CHECKPOINT = "vinai/bartpho-word-base"

    def __init__(self, model: str, vocab: Vocab):
        """Load every component required by the model called ``model``.

        Args:
            model: Name of the model to build; only ``"tfmwtr"`` is accepted.
            vocab: Vocabulary handed to ``TokenAligner``.

        Raises:
            NotImplementedError: If ``model`` is not a supported name. This is
                an ``Exception`` subclass, so ``except Exception`` handlers
                still catch it.
        """
        # Reject unsupported names up front, before anything is loaded.
        if model != "tfmwtr":
            raise NotImplementedError(f"Model {model} isn't implemented!")

        self.model_name = model

        checkpoint = self.PRETRAINED_CHECKPOINT
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.tokenAligner = TokenAligner(self.tokenizer, vocab)
        self.bart = transformers.MBartForConditionalGeneration.from_pretrained(
            checkpoint
        )
        self.model = TransformerWithTR(self.bart, self.tokenizer.pad_token_id)
        self.collator = DataCollatorForCharacterTransformer(self.tokenAligner)
        # NOTE(review): an earlier call to
        # `self.model.resize_token_embeddings(self.tokenAligner)` was disabled
        # here; confirm the embeddings do not need resizing.