# Uploaded by hoang1007 ("Upload 69 files", commit 44db343; originally 929 bytes).
from models.transformer import TransformerWithTR
from models.collator import *
from transformers import AutoTokenizer
import transformers
from models.tokenizer import TokenAligner
from dataset.vocab import Vocab
class ModelWrapper:
    """Bundle a model with the tokenizer, token aligner and collator it uses.

    Only the ``"tfmwtr"`` model name is supported. For it, the constructor
    loads a tokenizer and an ``MBartForConditionalGeneration`` backbone from
    the same pretrained checkpoint and wraps the backbone in
    ``TransformerWithTR``.

    Attributes:
        model_name: The ``model`` string passed to the constructor.
        tokenizer: Tokenizer loaded with ``AutoTokenizer.from_pretrained``.
        tokenAligner: ``TokenAligner`` built from ``tokenizer`` and ``vocab``.
        bart: The pretrained ``MBartForConditionalGeneration`` backbone.
        model: ``TransformerWithTR`` wrapping ``bart``; it also receives the
            tokenizer's ``pad_token_id``.
        collator: ``DataCollatorForCharacterTransformer`` built from
            ``tokenAligner``.
    """

    # Checkpoint used by default for both the tokenizer and the BART backbone.
    DEFAULT_PRETRAINED_NAME = "vinai/bartpho-word-base"

    def __init__(
        self,
        model: str,
        vocab: "Vocab",
        *,
        pretrained_name: str = DEFAULT_PRETRAINED_NAME,
    ):
        """Build the components for the model named *model*.

        Args:
            model: Model identifier; only ``"tfmwtr"`` is implemented.
            vocab: Vocabulary handed to ``TokenAligner``.
            pretrained_name: Checkpoint name (or path) from which both the
                tokenizer and the BART backbone are loaded.

        Raises:
            NotImplementedError: If *model* is not ``"tfmwtr"``. It is a
                subclass of ``Exception``, so existing broad handlers still
                catch it.
        """
        self.model_name = model
        if model != "tfmwtr":
            raise NotImplementedError(f"Model {model} isn't implemented!")

        # Tokenizer and backbone are loaded from the same checkpoint name.
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
        self.tokenAligner = TokenAligner(self.tokenizer, vocab)
        self.bart = transformers.MBartForConditionalGeneration.from_pretrained(
            pretrained_name
        )
        self.model = TransformerWithTR(self.bart, self.tokenizer.pad_token_id)
        self.collator = DataCollatorForCharacterTransformer(self.tokenAligner)
        # TODO(review): a call to
        # `self.model.resize_token_embeddings(self.tokenAligner)` was disabled
        # here; confirm whether embedding resizing is needed for this vocab.