taka-yamakoshi commited on
Commit
b04411c
1 Parent(s): 32ab467
Files changed (1) hide show
  1. app.py +2 -0
app.py CHANGED
@@ -56,11 +56,13 @@ def load_model(model_name):
56
  from skeleton_modeling_bert import SkeletonBertForMaskedLM
57
  tokenizer = BertTokenizer.from_pretrained(model_name)
58
  model = BertForMaskedLM.from_pretrained(model_name)
 
59
  elif model_name.startswith('roberta'):
60
  from transformers import RobertaTokenizer, RobertaForMaskedLM
61
  from skeleton_modeling_roberta import SkeletonRobertaForMaskedLM
62
  tokenizer = RobertaTokenizer.from_pretrained(model_name)
63
  model = RobertaForMaskedLM.from_pretrained(model_name)
 
64
  return tokenizer,model,skeleton_model
65
 
66
  def clear_data():
 
56
  from skeleton_modeling_bert import SkeletonBertForMaskedLM
57
  tokenizer = BertTokenizer.from_pretrained(model_name)
58
  model = BertForMaskedLM.from_pretrained(model_name)
59
+ skeleton_model = SkeletonBertForMaskedLM
60
  elif model_name.startswith('roberta'):
61
  from transformers import RobertaTokenizer, RobertaForMaskedLM
62
  from skeleton_modeling_roberta import SkeletonRobertaForMaskedLM
63
  tokenizer = RobertaTokenizer.from_pretrained(model_name)
64
  model = RobertaForMaskedLM.from_pretrained(model_name)
65
+ skeleton_model = SkeletonRobertaForMaskedLM
66
  return tokenizer,model,skeleton_model
67
 
68
  def clear_data():