Spaces:
Runtime error
Runtime error
d.tsimerman
committed on
Commit
•
5a9608a
1
Parent(s):
b5d743b
good
Browse files
app.py
CHANGED
@@ -10,10 +10,15 @@ model_name = st.selectbox(
|
|
10 |
('tinkoff-ai/response-quality-classifier-tiny', 'tinkoff-ai/response-quality-classifier-base', 'tinkoff-ai/response-quality-classifier-large')
|
11 |
)
|
12 |
auth_token = os.environ.get('TOKEN') or True
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
context_3 = 'привет'
|
19 |
context_2 = 'привет!'
|
@@ -83,12 +88,20 @@ def merge_dialog_data(
|
|
83 |
result[model_input_name] = result[model_input_name].cuda()
|
84 |
return result
|
85 |
|
86 |
-
tokenized_dialog = tokenize_dialog_data(tokenizer, sample, MAX_SEQ_LENGTH, sorted_dialog_columns)
|
87 |
-
tokens = merge_dialog_data(tokenizer, tokenized_dialog)
|
88 |
-
with torch.inference_mode():
|
89 |
-
logits = model(**tokens).logits
|
90 |
-
probas = torch.sigmoid(logits)[0].cpu().detach().numpy()
|
91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
st.metric(
|
93 |
label='Вероятность того, что последний ответ диалогового агента релевантный',
|
94 |
value=probas[0]
|
10 |
('tinkoff-ai/response-quality-classifier-tiny', 'tinkoff-ai/response-quality-classifier-base', 'tinkoff-ai/response-quality-classifier-large')
|
11 |
)
|
12 |
auth_token = os.environ.get('TOKEN') or True
|
13 |
+
|
14 |
+
# allow_output_mutation=True: the returned tokenizer/model are large objects
# that Streamlit should hand back as-is instead of hashing them on every cache
# hit to detect mutation.
# suppress_st_warning=True: the function deliberately calls st.spinner, which
# would otherwise be reported as a Streamlit call inside a cached function.
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_model(model_name: str):
    """Load the tokenizer and sequence-classification model for ``model_name``.

    The result is cached by Streamlit per ``model_name``, so switching back to
    a previously selected model does not reload it.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is moved to CUDA when
        ``torch.cuda.is_available()`` is true.
    """
    with st.spinner('Loading models...'):
        tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)
        model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name, use_auth_token=auth_token)
        if torch.cuda.is_available():
            model = model.cuda()
    return tokenizer, model
|
22 |
|
23 |
context_3 = 'привет'
|
24 |
context_2 = 'привет!'
|
88 |
result[model_input_name] = result[model_input_name].cuda()
|
89 |
return result
|
90 |
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
+
# suppress_st_warning=True: this cached function intentionally shows a spinner
# (and calls load_model, which shows one too); without the flag Streamlit
# reports those calls as Streamlit commands inside a cached function.
@st.cache(suppress_st_warning=True)
def inference(model_name: str, sample: dict):
    """Run the selected classifier on one dialog sample.

    Args:
        model_name: Model identifier forwarded to ``load_model``.
        sample: Dialog data forwarded to ``tokenize_dialog_data``
            (presumably keyed by the names in ``sorted_dialog_columns`` —
            confirm against the caller).

    Returns:
        NumPy array with the sigmoid of the logits for the first (only)
        item of the batch.
    """
    tokenizer, model = load_model(model_name)
    with st.spinner('Running inference...'):
        tokenized_dialog = tokenize_dialog_data(tokenizer, sample, MAX_SEQ_LENGTH, sorted_dialog_columns)
        tokens = merge_dialog_data(tokenizer, tokenized_dialog)
        # No autograd bookkeeping is needed for a forward-only pass.
        with torch.inference_mode():
            logits = model(**tokens).logits
        probas = torch.sigmoid(logits)[0].cpu().detach().numpy()
    return probas
|
102 |
+
|
103 |
+
|
104 |
+
probas = inference(model_name, sample)
|
105 |
st.metric(
|
106 |
label='Вероятность того, что последний ответ диалогового агента релевантный',
|
107 |
value=probas[0]
|