egoriya committed
Commit fa005e0
1 Parent(s): 8bf71eb

Update README.md

Files changed (1)
  1. README.md +76 -3
README.md CHANGED
@@ -9,10 +9,10 @@ widget:
  example_title: "Dialog example 3"
---

-
This classification model is based on [cointegrated/rubert-tiny2](https://huggingface.co/cointegrated/rubert-tiny2).
The model should be used to produce relevance and specificity scores for the last message in the context of a dialog.

+
It is pretrained on a corpus of dialog data from social networks and fine-tuned on [tinkoff-ai/context_similarity](https://huggingface.co/tinkoff-ai/context_similarity).
The performance of the model on the validation split of [tinkoff-ai/context_similarity](https://huggingface.co/tinkoff-ai/context_similarity) (with the best thresholds for validation samples):

@@ -27,8 +27,81 @@ The model can be loaded as follows:

```python
# pip install transformers
+ import transformers
from transformers import AutoTokenizer, AutoModel
- tokenizer = AutoTokenizer.from_pretrained("tinkoff-ai/context_similarity")
- model = AutoModel.from_pretrained("tinkoff-ai/context_similarity")
+ import torch
+ from typing import List, Dict
+ from transformers import AutoModelForSequenceClassification
+ tokenizer = AutoTokenizer.from_pretrained("tinkoff-ai/response-quality-classifier-tiny")
+ # Load the checkpoint with a sequence-classification head so the output exposes `.logits`.
+ model = AutoModelForSequenceClassification.from_pretrained("tinkoff-ai/response-quality-classifier-tiny")
# model.cuda()
+ # Example dialog: three context turns and the response to be scored
+ # ("hi" / "hi!" / "how are you?" -> "I'm fine, and you?").
+ context_3 = 'привет'
+ context_2 = 'привет!'
+ context_1 = 'как дела?'
+ response = 'у меня все хорошо, а у тебя как?'
+
+ sample = {
+     'context_3': context_3,
+     'context_2': context_2,
+     'context_1': context_1,
+     'response': response
+ }
+
+ SEP_TOKEN = '[SEP]'
+ CLS_TOKEN = '[CLS]'
+ RESPONSE_TOKEN = '[RESPONSE_TOKEN]'
+ MAX_SEQ_LENGTH = 128
+ sorted_dialog_columns = ['context_3', 'context_2', 'context_1', 'response']
+
+ def tokenize_dialog_data(
+     tokenizer: transformers.PreTrainedTokenizer,
+     sample: Dict,
+     max_seq_length: int,
+     sorted_dialog_columns: List,
+ ):
+     """
+     Tokenize the dialog contexts and the response separately.
+     """
+     len_message_history = len(sorted_dialog_columns)
+     max_seq_length = min(max_seq_length, tokenizer.model_max_length)
+     # Leave room for one special token appended after each message.
+     max_each_message_length = max_seq_length // len_message_history - 1
+     messages = [sample[k] for k in sorted_dialog_columns]
+     result = {model_input_name: [] for model_input_name in tokenizer.model_input_names}
+     messages = [str(message) if message is not None else '' for message in messages]
+     tokens = tokenizer(
+         messages, padding=False, max_length=max_each_message_length, truncation=True, add_special_tokens=False
+     )
+     for model_input_name in tokens.keys():
+         result[model_input_name].extend(tokens[model_input_name])
+     return result
+
+ def merge_dialog_data(
+     tokenizer: transformers.PreTrainedTokenizer,
+     sample: Dict
+ ):
+     """
+     Concatenate the per-message encodings into a single model input of the form
+     [CLS] context_3 [SEP] context_2 [SEP] context_1 [RESPONSE_TOKEN] response.
+     """
+     cls_token = tokenizer(CLS_TOKEN, add_special_tokens=False)
+     sep_token = tokenizer(SEP_TOKEN, add_special_tokens=False)
+     response_token = tokenizer(RESPONSE_TOKEN, add_special_tokens=False)
+     model_input_names = tokenizer.model_input_names
+     result = {}
+     for model_input_name in model_input_names:
+         tokens = []
+         tokens.extend(cls_token[model_input_name])
+         for i, message in enumerate(sample[model_input_name]):
+             tokens.extend(message)
+             if i < len(sample[model_input_name]) - 2:
+                 tokens.extend(sep_token[model_input_name])
+             elif i == len(sample[model_input_name]) - 2:
+                 tokens.extend(response_token[model_input_name])
+         result[model_input_name] = torch.tensor([tokens])
+         # Inputs are moved to GPU here; this requires `model.cuda()` above to be uncommented.
+         if torch.cuda.is_available():
+             result[model_input_name] = result[model_input_name].cuda()
+     return result
+
+ tokenized_dialog = tokenize_dialog_data(tokenizer, sample, MAX_SEQ_LENGTH, sorted_dialog_columns)
+ tokens = merge_dialog_data(tokenizer, tokenized_dialog)
+ with torch.inference_mode():
+     logits = model(**tokens).logits
+     probas = torch.sigmoid(logits)[0].cpu().detach().numpy()
+
+ # Two scores for the response: its relevance and its specificity in the given context.
+ print(probas)
  ```
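
For reference, the two helpers in the updated snippet build a single input of the form `[CLS]context_3[SEP]context_2[SEP]context_1[RESPONSE_TOKEN]response` and read two sigmoid scores from the classifier head. The sketch below, which is not part of the commit, shows the same usage in compact form; it assumes `[RESPONSE_TOKEN]` is in this checkpoint's tokenizer vocabulary and that the two outputs are relevance and specificity in that order, neither of which is stated explicitly in this diff.

```python
# Compact sketch of the usage shown in the diff above (assumptions noted in the lead-in).
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained('tinkoff-ai/response-quality-classifier-tiny')
model = AutoModelForSequenceClassification.from_pretrained('tinkoff-ai/response-quality-classifier-tiny')

# Same dialog as above, flattened into the layout built by merge_dialog_data.
text = '[CLS]привет[SEP]привет![SEP]как дела?[RESPONSE_TOKEN]у меня все хорошо, а у тебя как?'
inputs = tokenizer(text, max_length=128, add_special_tokens=False, truncation=True, return_tensors='pt')

with torch.inference_mode():
    probas = torch.sigmoid(model(**inputs).logits)[0].cpu().detach().numpy()

# Assumption: index 0 is relevance, index 1 is specificity.
relevance, specificity = probas
print(f'relevance: {relevance:.3f}, specificity: {specificity:.3f}')
```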