hgrif commited on
Commit
81d3f6d
β€’
1 Parent(s): 9204ef7

Switch to AutoModels

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -11,7 +11,7 @@ import numpy as np
11
  import tensorflow as tf
12
  import streamlit as st
13
  from gazpacho import Soup, get
14
- from transformers import BertTokenizer, TFBertForMaskedLM
15
 
16
 
17
  DEFAULT_QUERY = "Machines will take over the world soon"
@@ -85,8 +85,8 @@ def start_rhyming(query, rhyme_words_options):
85
  @st.cache(allow_output_mutation=True)
86
  def load_model(model_path):
87
  return (
88
- TFBertForMaskedLM.from_pretrained(model_path),
89
- BertTokenizer.from_pretrained(model_path),
90
  )
91
 
92
 
@@ -119,8 +119,8 @@ class TokenWeighter:
119
  class RhymeGenerator:
120
  def __init__(
121
  self,
122
- model: TFBertForMaskedLM,
123
- tokenizer: BertTokenizer,
124
  token_weighter: TokenWeighter = None,
125
  ):
126
  """Generate rhymes.
 
11
  import tensorflow as tf
12
  import streamlit as st
13
  from gazpacho import Soup, get
14
+ from transformers import AutoTokenizer, TFAutoModelForMaskedLM
15
 
16
 
17
  DEFAULT_QUERY = "Machines will take over the world soon"
 
85
  @st.cache(allow_output_mutation=True)
86
  def load_model(model_path):
87
  return (
88
+ TFAutoModelForMaskedLM.from_pretrained(model_path),
89
+ AutoTokenizer.from_pretrained(model_path),
90
  )
91
 
92
 
 
119
  class RhymeGenerator:
120
  def __init__(
121
  self,
122
+ model: TFAutoModelForMaskedLM,
123
+ tokenizer: AutoTokenizer,
124
  token_weighter: TokenWeighter = None,
125
  ):
126
  """Generate rhymes.