d.tsimerman committed on
Commit
bf1f054
1 Parent(s): f5e42ed
Files changed (1) hide show
  1. app.py +9 -10
app.py CHANGED
@@ -12,7 +12,7 @@ model_name = st.selectbox(
12
  )
13
  auth_token = os.environ.get('TOKEN') or True
14
 
15
- @st.cache(hash_funcs={tokenizers.Tokenizer: lambda tokenizer: hash(tokenizer.to_str())})
16
  def load_model(model_name: str):
17
  with st.spinner('Loading models...'):
18
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)
@@ -93,16 +93,15 @@ def merge_dialog_data(
93
  @st.cache
94
  def inference(model_name: str, sample: dict):
95
  tokenizer, model = load_model(model_name)
96
- with st.spinner('Running inference...'):
97
- tokenized_dialog = tokenize_dialog_data(tokenizer, sample, MAX_SEQ_LENGTH, sorted_dialog_columns)
98
- tokens = merge_dialog_data(tokenizer, tokenized_dialog)
99
- with torch.inference_mode():
100
- logits = model(**tokens).logits
101
- probas = torch.sigmoid(logits)[0].cpu().detach().numpy()
102
- return probas
103
 
104
-
105
- probas = inference(model_name, sample)
106
  st.metric(
107
  label='Вероятность того, что последний ответ диалогового агента релевантный',
108
  value=probas[0]
12
  )
13
  auth_token = os.environ.get('TOKEN') or True
14
 
15
+ @st.cache(hash_funcs={tokenizers.Tokenizer: lambda tokenizer: hash(tokenizer.to_str())}, allow_output_mutation=True)
16
  def load_model(model_name: str):
17
  with st.spinner('Loading models...'):
18
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)
93
  @st.cache
94
  def inference(model_name: str, sample: dict):
95
  tokenizer, model = load_model(model_name)
96
+ tokenized_dialog = tokenize_dialog_data(tokenizer, sample, MAX_SEQ_LENGTH, sorted_dialog_columns)
97
+ tokens = merge_dialog_data(tokenizer, tokenized_dialog)
98
+ with torch.inference_mode():
99
+ logits = model(**tokens).logits
100
+ probas = torch.sigmoid(logits)[0].cpu().detach().numpy()
101
+ return probas
 
102
 
103
+ with st.spinner('Running inference...'):
104
+ probas = inference(model_name, sample)
105
  st.metric(
106
  label='Вероятность того, что последний ответ диалогового агента релевантный',
107
  value=probas[0]