---
license: openrail
datasets:
- AlekseyKorshuk/persona-chat
language:
- en
pipeline_tag: text-generation
---

# Model Card for DeepPavlov/bart-base-en-persona-chat

<!-- Provide a quick summary of what the model is/does. -->

A seq2seq model fine-tuned from [facebook/bart-base](https://huggingface.co/facebook/bart-base) on the English PersonaChat dataset to generate persona-conditioned dialogue responses.


# Model Details

## Model Description

<!-- Provide a longer summary of what this model is. -->

- **Developed by:** DeepPavlov team
- **Model type:** seq2seq
- **Language(s) (NLP):** English
- **License:** MIT
- **Finetuned from model:** [facebook/bart-base](https://huggingface.co/facebook/bart-base)


# Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

## Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

```python
from typing import List, Optional, TypedDict
from dataclasses import dataclass
from itertools import chain

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch


@dataclass
class H2PersonaChatHyperparametersV1:
    """
    chat_history_pair_length: int - number of dialogue pairs kept from the end of the chat history
    """

    model_name: str = "facebook/bart-base"
    chat_history_pair_length: int = 7

    persona_max_length: int = 14
    chat_max_length: int = 25

    debug_status: int = 0


class PersonaChatDatasetSampleV1(TypedDict):
    """
    persona: List[str] - set of persona fact sentences
    history: List[str] - chat history
    """

    persona: List[str]
    history: List[str]
    sample_id: str


class H2Seq2SeqInferenceSampleDictV1(TypedDict):
    input_ids: List[int]
    attention_mask: List[int]


class H2Seq2SeqInferenceSampleDictV2(TypedDict):
    input_ids: torch.Tensor
    attention_mask: torch.Tensor


def flat_list(list_of_lists: List[List]) -> List:
    # flatten a list of lists into a single list
    return list(chain.from_iterable(list_of_lists))


class H2Seq2SeqInferencePersonaSampleV1:
    def __init__(
        self,
        dataset_sample: PersonaChatDatasetSampleV1,
        tokenizer: AutoTokenizer,
        hyperparameters: H2PersonaChatHyperparametersV1,
    ) -> None:
        self.dataset_sample = dataset_sample
        self.tokenizer = tokenizer
        self.hyperparameters = hyperparameters

    def add_spaces_after(
        self,
        items: List[str],
    ) -> List[str]:
        items = [item + " " for item in items]
        return items

    @property
    def bos_token_id(self):
        # T5-style models have no BOS token
        if "t5" in self.hyperparameters.model_name:
            return []

        if self.tokenizer.bos_token_id is None:
            return []

        return [self.tokenizer.bos_token_id]

    @property
    def eos_token_id(self):
        if self.tokenizer.eos_token_id is None:
            return []

        return [self.tokenizer.eos_token_id]

    def add_sep_between(self, items: List[str], sep=" EOS ") -> List[str]:
        for i in range(1, len(items)):
            items[i] = sep + items[i]

        return items

    def add_spaces_between(self, items: List[str]) -> List[str]:
        items = self.add_spaces_after(items)
        items[-1] = items[-1].strip()
        return items

    def get_sample(self) -> H2Seq2SeqInferenceSampleDictV1:
        # keep the last chat_history_pair_length pairs plus the most recent utterance
        dialog_history = self.dataset_sample["history"]
        dialog_history = dialog_history[-self.hyperparameters.chat_history_pair_length * 2 - 1 :]
        dialog_history = self.add_sep_between(dialog_history)

        persona = self.dataset_sample["persona"]
        persona = self.add_sep_between(
            persona,
            sep=" ",
        )

        KNOWLEDGE_IDS = self.tokenizer.encode(
            " [KNOWLEDGE] ",
            add_special_tokens=False,
        )
        CONTEXT_IDS = self.tokenizer.encode(
            " [CONTEXT] ",
            add_special_tokens=False,
        )

        encoded_history = self.tokenizer.batch_encode_plus(
            dialog_history,
            add_special_tokens=False,
            truncation=True,
            max_length=self.hyperparameters.chat_max_length,
        )
        encoded_history = flat_list(encoded_history["input_ids"])

        encoded_persona = self.tokenizer.batch_encode_plus(
            persona,
            add_special_tokens=False,
            truncation=True,
            max_length=self.hyperparameters.persona_max_length,
        )
        encoded_persona = flat_list(encoded_persona["input_ids"])

        # final layout: <bos> [CONTEXT] history [KNOWLEDGE] persona <eos>
        input_ids = [
            *self.bos_token_id,
            *CONTEXT_IDS,
            *encoded_history,
            *KNOWLEDGE_IDS,
            *encoded_persona,
            *self.eos_token_id,
        ]

        attention_mask = [1] * len(input_ids)

        return H2Seq2SeqInferenceSampleDictV1(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )


class DialogBotV1:
    def __init__(
        self,
        model: AutoModelForSeq2SeqLM,
        tokenizer: AutoTokenizer,
        hyperparameters: H2PersonaChatHyperparametersV1,
        history: Optional[List[str]] = None,
        persona: Optional[List[str]] = None,
        device: str = "cuda",
        shuffle_persona: bool = True,
    ):
        self.model = model

        self.tokenizer = tokenizer
        self.hyperparameters = hyperparameters
        self.device = device
        self.shuffle_persona = shuffle_persona

        self.debug_status = hyperparameters.debug_status

        self.history = history if history is not None else []
        self.persona = persona if persona is not None else []

    def _get_sample(
        self,
        persona: List[str],
        history: List[str],
    ) -> H2Seq2SeqInferenceSampleDictV1:
        dataset_sample = PersonaChatDatasetSampleV1(
            persona=persona,
            history=history,
            sample_id="",  # unused at inference time
        )

        sample = H2Seq2SeqInferencePersonaSampleV1(
            tokenizer=self.tokenizer,
            hyperparameters=self.hyperparameters,
            dataset_sample=dataset_sample,
        )
        sample = sample.get_sample()

        if self.debug_status:
            print(self.tokenizer.decode(sample["input_ids"]))

        # add a batch dimension and move tensors to the target device
        for key in sample.keys():
            sample[key] = torch.tensor(sample[key]).unsqueeze(0).to(self.device)

        return sample

    def next_response(
        self,
        **generation_params,
    ) -> str:
        sample = self._get_sample(
            persona=self.persona,
            history=self.history,
        )
        answer = self.generate_response(
            sample,
            **generation_params,
        )
        answer = self.tokenizer.batch_decode(
            answer,
            skip_special_tokens=True,
        )
        self.history.append(answer[0])
        return answer[0]

    def generate_response(
        self,
        sample: H2Seq2SeqInferenceSampleDictV1,
        **generation_params,
    ):
        """
        generation_params - https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/text_generation
        """
        with torch.no_grad():
            return self.model.generate(
                **sample,
                **generation_params,
            )


PRETRAINED_MODEL_NAME_OR_PATH = "DeepPavlov/bart-base-en-persona-chat"

PAIR_DIALOG_HISTORY_LENGTH = 2

# CHAT_MAX_LENGTH for a single sentence, in tokens
CHAT_MAX_LENGTH = 25
# PERSONA_MAX_LENGTH for a single sentence, in tokens
PERSONA_MAX_LENGTH = 19

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForSeq2SeqLM.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
model.to(device)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)

if torch.cuda.is_available():
    model.half()

hyperparameters = H2PersonaChatHyperparametersV1(
    chat_history_pair_length=PAIR_DIALOG_HISTORY_LENGTH,
    persona_max_length=PERSONA_MAX_LENGTH,
    chat_max_length=CHAT_MAX_LENGTH,
    model_name=PRETRAINED_MODEL_NAME_OR_PATH,
)


persona = [
    "I like to play guitar.",
    "I hate onions.",
]

history = [
    "I hate to talk about politics, what about you?",
]

persona_bot = DialogBotV1(
    model=model,
    tokenizer=tokenizer,
    hyperparameters=hyperparameters,
    history=history,
    persona=persona,
    device=device,
)

# penalty_alpha together with top_k enables contrastive search decoding
GENERATION_PARAMS = {
    "max_new_tokens": 60,
    "penalty_alpha": 0.15,
    "top_k": 10,
}
response = persona_bot.next_response(
    **GENERATION_PARAMS,
)

print(response)
# i am not into politics. i am into music.
```
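
To hold a longer conversation, append the user's utterance to the bot's `history` before requesting another reply; `next_response` only appends the bot's own answer. A minimal continuation of the snippet above (the user message here is illustrative):

```python
# continue the dialogue: add the user's turn, then generate the bot's reply
user_message = "What kind of music do you like?"
persona_bot.history.append(user_message)

reply = persona_bot.next_response(**GENERATION_PARAMS)
print(reply)  # the next persona-conditioned utterance
```
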

# Training Details

## Training Data

<!-- This should link to a Data Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

- [Data Source | EN Persona Chat](https://s3.amazonaws.com/datasets.huggingface.co/personachat/personachat_self_original.json)
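
The raw file can be fetched directly from the URL above; a small helper sketch (the local paths are illustrative and match the preprocessing script below):

```python
import os
import urllib.request

URL = "https://s3.amazonaws.com/datasets.huggingface.co/personachat/personachat_self_original.json"

os.makedirs("./datasets/persona_chat", exist_ok=True)
urllib.request.urlretrieve(URL, "./datasets/persona_chat/persona_chat.json")
```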

### Preprocessing

- The initial data was split into train, validation, and test sets with the script below:

```python
import json


def persona_chat_dataset_transformer_v1(
    initial_dataset_path: str,
    output_folder: str,
) -> None:
    """
    example:
    persona_chat_dataset_transformer_v1(
        initial_dataset_path="./datasets/persona_chat/persona_chat.json",
        output_folder="./datasets/persona_chat",
    )
    """
    assert initial_dataset_path is not None, "initial_dataset_path is None"
    assert output_folder is not None, "output_folder is None"

    with open(initial_dataset_path) as f:
        initial_dataset = json.load(f)

    # the original "valid" split is halved into validation and test sets
    train_dataset = initial_dataset["train"]
    val_len = len(initial_dataset["valid"])
    valid_dataset = initial_dataset["valid"][: val_len // 2]
    test_dataset = initial_dataset["valid"][val_len // 2 :]

    print(
        f"Dataset lengths: train {len(train_dataset)}, valid {len(valid_dataset)}, test {len(test_dataset)}"
    )

    # save json files
    with open(output_folder + "/train.json", "w") as f:
        json.dump(train_dataset, f)

    with open(output_folder + "/valid.json", "w") as f:
        json.dump(valid_dataset, f)

    with open(output_folder + "/test.json", "w") as f:
        json.dump(test_dataset, f)

    print("Datasets saved.")
```
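
The saved splits can then be loaded back, for example with the `datasets` library (this loading step is our illustration; the card does not show how the splits were consumed during training):

```python
from datasets import load_dataset

# assumes the files produced by the split script above
persona_chat = load_dataset(
    "json",
    data_files={
        "train": "./datasets/persona_chat/train.json",
        "valid": "./datasets/persona_chat/valid.json",
        "test": "./datasets/persona_chat/test.json",
    },
)
print(persona_chat)
```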

# Evaluation

### Metrics

<!-- These are the evaluation metrics being used, ideally with a description of why. -->
- BLEU
- chrF
- ROUGE-L
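
A minimal sketch of computing these metrics with the Hugging Face `evaluate` library (the library choice and the toy strings are our assumptions; the card does not include evaluation code):

```python
import evaluate

# toy predictions and references, for illustration only
predictions = ["i am not into politics. i am into music."]
references = [["me neither, i would rather talk about music."]]

bleu = evaluate.load("sacrebleu")  # BLEU
chrf = evaluate.load("chrf")       # chrF
rouge = evaluate.load("rouge")     # ROUGE-L

print(bleu.compute(predictions=predictions, references=references)["score"])
print(chrf.compute(predictions=predictions, references=references)["score"])
print(rouge.compute(predictions=predictions, references=references)["rougeL"])
```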