d.tsimerman committed on
Commit
5a9608a
1 Parent(s): b5d743b
Files changed (1)
  1. app.py +22 -9
app.py CHANGED
@@ -10,10 +10,15 @@ model_name = st.selectbox(
     ('tinkoff-ai/response-quality-classifier-tiny', 'tinkoff-ai/response-quality-classifier-base', 'tinkoff-ai/response-quality-classifier-large')
 )
 auth_token = os.environ.get('TOKEN') or True
-tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)
-model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name, use_auth_token=auth_token)
-if torch.cuda.is_available():
-    model = model.cuda()
+
+@st.cache
+def load_model(model_name: str):
+    with st.spinner('Loading models...'):
+        tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)
+        model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name, use_auth_token=auth_token)
+        if torch.cuda.is_available():
+            model = model.cuda()
+        return tokenizer, model
 
 context_3 = 'привет'
 context_2 = 'привет!'
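
Why the change helps: Streamlit re-executes the entire script on every widget interaction, so before this commit the tokenizer and model were re-instantiated each time the user touched any control. The st.cache decorator memoizes the function on its arguments (here model_name), so each of the three checkpoints is loaded at most once per session. A toy sketch of that mechanism, separate from the app's code (slow_square and the slider are illustrative only):

# Toy sketch (not part of the commit): @st.cache skips repeated work
# across Streamlit's script re-runs by memoizing on the arguments.
import time

import streamlit as st

@st.cache
def slow_square(x: int) -> int:
    time.sleep(2)  # stands in for an expensive model download/load
    return x * x

n = st.slider('n', 1, 10)
st.write(slow_square(n))  # ~2 s the first time each n is picked, instant after

One caveat, offered as an assumption about this app rather than an observed bug: st.cache also hashes a function's return value to detect mutation, and torch models returned from a cached loader often need allow_output_mutation=True to avoid hashing errors; this commit relies on the defaults.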
@@ -83,12 +88,20 @@ def merge_dialog_data(
         result[model_input_name] = result[model_input_name].cuda()
     return result
 
-tokenized_dialog = tokenize_dialog_data(tokenizer, sample, MAX_SEQ_LENGTH, sorted_dialog_columns)
-tokens = merge_dialog_data(tokenizer, tokenized_dialog)
-with torch.inference_mode():
-    logits = model(**tokens).logits
-    probas = torch.sigmoid(logits)[0].cpu().detach().numpy()
 
+@st.cache
+def inference(model_name: str, sample: dict):
+    tokenizer, model = load_model(model_name)
+    with st.spinner('Running inference...'):
+        tokenized_dialog = tokenize_dialog_data(tokenizer, sample, MAX_SEQ_LENGTH, sorted_dialog_columns)
+        tokens = merge_dialog_data(tokenizer, tokenized_dialog)
+        with torch.inference_mode():
+            logits = model(**tokens).logits
+            probas = torch.sigmoid(logits)[0].cpu().detach().numpy()
+    return probas
+
+
+probas = inference(model_name, sample)
 st.metric(
     label='Вероятность того, что последний ответ диалогового агента релевантный',
     value=probas[0]
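
The inference wrapper is cached the same way, keyed on (model_name, sample): re-running the script with an unchanged dialog skips tokenization and the forward pass entirely, and the load_model call inside it is itself a cache hit (the sample dict of short strings is cheap for Streamlit to hash). The post-processing applies an element-wise sigmoid rather than a softmax, so each output label gets its own independent probability; the app then displays probas[0] as the probability that the dialogue agent's last reply is relevant (that is what the Russian st.metric label says). A standalone illustration of that last step, with invented logit values:

# Standalone sketch of the post-processing; the logit values are invented.
import torch

logits = torch.tensor([[1.2, -0.3]])  # stand-in for model(**tokens).logits
probas = torch.sigmoid(logits)[0].cpu().detach().numpy()
print(probas)  # ~[0.769, 0.426]: one independent probability per label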