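# Provenance (recovered from the file-hosting page header that was pasted
# into this file): uploaded by Katsumata420, commit bff929a ("Upload scripts"),
# 823 bytes, host malware scan: "No virus".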
from transformers import BertJapaneseTokenizer
from transformers import BertConfig
from transformers import BertForMaskedLM
from transformers import pipeline
# Japanese prompts for the fill-mask pipeline; each contains exactly one
# literal "[MASK]" placeholder. These literals were previously mojibake
# (UTF-8 bytes mis-decoded through a Thai/Windows code page) and have been
# restored to the intended Japanese text.
inputs = [
    '[MASK]もそう思います',
    '[MASK]なんというかその',
    'これは[MASK]私が子供の頃の話なんですけど',
]
# Local checkpoint directories. Prompts and models are paired by position:
# inputs[i] is evaluated only with model_name_list[i].
model_name_list = [
    'models/1-6_layer-wise',
    'models/tapt512_60K',
    'models/dapt128-tapt512',
]


def main():
    """Run each checkpoint on its paired prompt and print the fill-mask output.

    For every (prompt, checkpoint) pair this loads the tokenizer, config and
    masked-LM model from the checkpoint directory, prints the model name and
    the prompt, then prints the pipeline's predictions followed by a blank line.

    Raises:
        ValueError: if ``inputs`` and ``model_name_list`` differ in length
            (``zip`` would otherwise silently skip the unmatched entries).
    """
    if len(inputs) != len(model_name_list):
        raise ValueError(
            f'inputs ({len(inputs)}) and model_name_list '
            f'({len(model_name_list)}) must have the same length'
        )
    for input_, model_name in zip(inputs, model_name_list):
        tokenizer = BertJapaneseTokenizer.from_pretrained(model_name)
        config = BertConfig.from_pretrained(model_name)
        model = BertForMaskedLM.from_pretrained(model_name)
        print('model name:', model_name)
        print('input:', input_)
        fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer, config=config)
        print('output:', fill_mask(input_))
        print()


if __name__ == '__main__':
    main()