taka-yamakoshi
commited on
Commit
•
b04411c
1
Parent(s):
32ab467
fix
Browse files
app.py
CHANGED
@@ -56,11 +56,13 @@ def load_model(model_name):
|
|
56 |
from skeleton_modeling_bert import SkeletonBertForMaskedLM
|
57 |
tokenizer = BertTokenizer.from_pretrained(model_name)
|
58 |
model = BertForMaskedLM.from_pretrained(model_name)
|
|
|
59 |
elif model_name.startswith('roberta'):
|
60 |
from transformers import RobertaTokenizer, RobertaForMaskedLM
|
61 |
from skeleton_modeling_roberta import SkeletonRobertaForMaskedLM
|
62 |
tokenizer = RobertaTokenizer.from_pretrained(model_name)
|
63 |
model = RobertaForMaskedLM.from_pretrained(model_name)
|
|
|
64 |
return tokenizer,model,skeleton_model
|
65 |
|
66 |
def clear_data():
|
|
|
56 |
from skeleton_modeling_bert import SkeletonBertForMaskedLM
|
57 |
tokenizer = BertTokenizer.from_pretrained(model_name)
|
58 |
model = BertForMaskedLM.from_pretrained(model_name)
|
59 |
+
skeleton_model = SkeletonBertForMaskedLM
|
60 |
elif model_name.startswith('roberta'):
|
61 |
from transformers import RobertaTokenizer, RobertaForMaskedLM
|
62 |
from skeleton_modeling_roberta import SkeletonRobertaForMaskedLM
|
63 |
tokenizer = RobertaTokenizer.from_pretrained(model_name)
|
64 |
model = RobertaForMaskedLM.from_pretrained(model_name)
|
65 |
+
skeleton_model = SkeletonRobertaForMaskedLM
|
66 |
return tokenizer,model,skeleton_model
|
67 |
|
68 |
def clear_data():
|