crscardellino committed
Commit f781de0
1 Parent(s): 534773e

Added device to the steps

Files changed (4):
  1. .flake8 +8 -0
  2. chatbot.py +65 -55
  3. flisol-cordoba-2023.ipynb +35 -31
  4. utils.py +8 -7
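
The thread running through the diffs below is device placement: the model is moved onto a torch.device once, and every encoded prompt is moved to that same device before calling generate(). A minimal sketch of the pattern this commit applies (the model name is the repository's default; the prompt string is only illustrative):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Pick a GPU when available, otherwise fall back to CPU (the same default the commit uses).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m").to(device)

# The input tensors must live on the same device as the model before generate().
input_ids = tokenizer.encode("Hola, ¿cómo estás?", return_tensors="pt").to(device)
output = model.generate(input_ids, max_length=input_ids.shape[1] + 20)
print(tokenizer.decode(output[0], skip_special_tokens=True))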
.flake8 ADDED
@@ -0,0 +1,8 @@
+[flake8]
+exclude =
+    __pycache__
+    migrations/
+    .venv/
+    *venv/
+max-line-length = 100
+extend-ignore = E203,E501,E701
chatbot.py CHANGED
@@ -17,8 +17,12 @@ prompt.
 import argparse
 import torch
 
-from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel,\
-    PreTrainedTokenizerBase
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    PreTrainedModel,
+    PreTrainedTokenizerBase,
+)
 from typing import Optional, Union
 
 
@@ -51,42 +55,47 @@ class ChatBot:
     bot_identifier : str
         The string that will identify the bot speaker in the prompt (e.g.
         EXPERT).
+    device: torch.device
+        Device to run the model
     """
 
-    def __init__(self,
-                 base_model: Union[str, PreTrainedModel],
-                 tokenizer: Optional[PreTrainedTokenizerBase] = None,
-                 initial_prompt: Optional[str] = None,
-                 keep_context: bool = False,
-                 creative: bool = False,
-                 max_tokens: int = 50,
-                 human_identifier: str = 'HUMAN',
-                 bot_identifier: str = 'EXPERT'):
+    def __init__(
+        self,
+        base_model: Union[str, PreTrainedModel],
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
+        initial_prompt: Optional[str] = None,
+        keep_context: bool = False,
+        creative: bool = False,
+        max_tokens: int = 50,
+        human_identifier: str = "HUMAN",
+        bot_identifier: str = "EXPERT",
+        device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+    ):
         if isinstance(base_model, str):
             self.model = AutoModelForCausalLM.from_pretrained(
-                base_model,
-                low_cpu_mem_usage=True,
-                torch_dtype='auto'
-            )
+                base_model, low_cpu_mem_usage=True, torch_dtype="auto"
+            ).to(device)
             self.tokenizer = AutoTokenizer.from_pretrained(base_model)
         else:
-            assert isinstance(tokenizer, PreTrainedTokenizerBase),\
-                "If the base model is given, the tokenizer should be given as well"
-            self.model = base_model
+            assert isinstance(
+                tokenizer, PreTrainedTokenizerBase
+            ), "If the base model is given, the tokenizer should be given as well"
+            self.model = base_model.to(device)
             self.tokenizer = tokenizer
 
         if initial_prompt is None:
-            with open('./prompt.txt', 'r') as fh:
+            with open("./prompt.txt", "r") as fh:
                 self.initial_prompt = fh.read()
         else:
            self.initial_prompt = initial_prompt
 
         self.keep_context = keep_context
-        self.context = ''
+        self.context = ""
         self.creative = creative
         self.max_tokens = max_tokens
         self.human_identifier = human_identifier
         self.bot_identifier = bot_identifier
+        self.device = device
 
     def chat(self, input_text: str) -> str:
         """
@@ -113,10 +122,10 @@ class ChatBot:
         # start the dialog between the human and the bot. Give space for the
         # model to continue from the prompt
         prompt = self.initial_prompt + self.context
-        prompt += f'{self.human_identifier}: {input_text}\n'
-        prompt += f'{self.bot_identifier}: ' # check the space after the colon
+        prompt += f"{self.human_identifier}: {input_text}\n"
+        prompt += f"{self.bot_identifier}: " # check the space after the colon
 
-        input_ids = self.tokenizer.encode(prompt, return_tensors='pt')
+        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
         if self.creative:
             # In case you want the bot to be creative, we sample using `top_k`
             # and `top_p`
@@ -125,13 +134,12 @@
                 do_sample=True,
                 max_length=input_ids.shape[1] + self.max_tokens,
                 top_k=50,
-                top_p=0.95
+                top_p=0.95,
             )[0]
         else:
             # Otherwise we return the most probable token
             output = self.model.generate(
-                input_ids,
-                max_length=input_ids.shape[1] + self.max_tokens
+                input_ids, max_length=input_ids.shape[1] + self.max_tokens
             )[0]
 
         # Decode the output, removing special tokens for the model (like
@@ -139,11 +147,11 @@ class ChatBot:
         decoded_output = self.tokenizer.decode(output, skip_special_tokens=True)
 
         # Trim the output, first by removing the original prompt
-        trimmed_output = decoded_output[len(prompt):]
+        trimmed_output = decoded_output[len(prompt) :]
 
         # Then we find the stop token, in this case the human identifier, and
         # we get up to that point
-        trimmed_output = trimmed_output[:trimmed_output.find(f'{self.human_identifier}:')]
+        trimmed_output = trimmed_output[: trimmed_output.find(f"{self.human_identifier}:")]
 
         if self.keep_context:
             # If we want to keep the context of the conversation we add the
@@ -153,36 +161,38 @@
         return trimmed_output.strip() # we only return the trimmed output
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model-name', '-m',
-                        default='bigscience/bloom-560m',
-                        help="Name of the base model to use for the chatbot")
-    parser.add_argument('--prompt', '-p',
-                        default='./prompt.txt',
-                        help="Path to the file with the prompt to use")
-    parser.add_argument('--keep-context', '-k',
-                        action='store_true',
-                        help="Keep context of the conversation.")
-    parser.add_argument('--creative', '-c',
-                        action='store_true',
-                        help="Make the bot creative when answering.")
-    parser.add_argument('--random-seed', '-r',
-                        default=42,
-                        help="Seed number for the creative bot.",
-                        type=int)
-    parser.add_argument('--human-identifier', '-i',
-                        default='HUMANO',
-                        help="Name of the human identifier.")
-    parser.add_argument('--bot-identifier', '-b',
-                        default='EXPERTO',
-                        help="Name of the bot identifier.")
+    parser.add_argument(
+        "--model-name",
+        "-m",
+        default="bigscience/bloom-560m",
+        help="Name of the base model to use for the chatbot",
+    )
+    parser.add_argument(
+        "--prompt", "-p", default="./prompt.txt", help="Path to the file with the prompt to use"
+    )
+    parser.add_argument(
+        "--keep-context", "-k", action="store_true", help="Keep context of the conversation."
+    )
+    parser.add_argument(
+        "--creative", "-c", action="store_true", help="Make the bot creative when answering."
+    )
+    parser.add_argument(
+        "--random-seed", "-r", default=42, help="Seed number for the creative bot.", type=int
+    )
+    parser.add_argument(
+        "--human-identifier", "-i", default="HUMANO", help="Name of the human identifier."
+    )
+    parser.add_argument(
+        "--bot-identifier", "-b", default="EXPERTO", help="Name of the bot identifier."
+    )
 
     args = parser.parse_args()
 
     torch.manual_seed(args.random_seed)
 
-    with open(args.prompt, 'r') as fh:
+    with open(args.prompt, "r") as fh:
         initial_prompt = fh.read()
 
     chatbot = ChatBot(
@@ -191,12 +201,12 @@ if __name__ == '__main__':
         keep_context=args.keep_context,
         creative=args.creative,
         human_identifier=args.human_identifier,
-        bot_identifier=args.bot_identifier
+        bot_identifier=args.bot_identifier,
     )
 
     print("Write `exit` or `quit` to quit")
     while True:
-        input_text = input('> ')
-        if input_text == 'exit' or input_text == 'quit':
+        input_text = input("> ")
+        if input_text == "exit" or input_text == "quit":
             break
         print(chatbot.chat(input_text))
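
With this change the ChatBot constructor accepts a device argument, defaulting to CUDA when available and CPU otherwise. An illustrative usage sketch (the question and the explicit CPU device are placeholders; ./prompt.txt is read for the initial prompt, as in the script defaults):

import torch
from chatbot import ChatBot

# Explicitly run on CPU; leaving `device` out selects CUDA automatically when available.
bot = ChatBot(
    base_model="bigscience/bloom-560m",
    keep_context=True,
    human_identifier="HUMANO",
    bot_identifier="EXPERTO",
    device=torch.device("cpu"),
)

# The question below is only an example; replace it with your own input.
print(bot.chat("¿Quién escribió el Martín Fierro?"))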
flisol-cordoba-2023.ipynb CHANGED
@@ -272,14 +272,11 @@
     "from IPython.display import display, HTML\n",
     "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
     "\n",
-    "BASE_MODEL = 'bigscience/bloom-3b' # More models at https://huggingface.co/models\n",
+    "BASE_MODEL = \"bigscience/bloom-3b\" # More models at https://huggingface.co/models\n",
     "\n",
+    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
     "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)\n",
-    "model = AutoModelForCausalLM.from_pretrained(\n",
-    "    BASE_MODEL,\n",
-    "    low_cpu_mem_usage=True,\n",
-    "    torch_dtype='auto'\n",
-    ")"
+    "model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, low_cpu_mem_usage=True, torch_dtype=\"auto\").to(device)"
    ]
   },
   {
@@ -409,7 +406,7 @@
    ],
    "source": [
     "MAX_TOKENS = 50\n",
-    "input_ids = tokenizer.encode(PROMPT, return_tensors='pt')\n",
+    "input_ids = tokenizer.encode(PROMPT, return_tensors=\"pt\").to(device)\n",
     "greedy_output = model.generate(input_ids, max_length=input_ids.shape[1] + MAX_TOKENS)\n",
     "output = tokenizer.decode(greedy_output[0], skip_special_tokens=True)\n",
     "\n",
@@ -525,17 +522,17 @@
     "\"\"\".strip()\n",
     "\n",
     "chatbot = ChatBot(\n",
-    "    base_model='bigscience/bloom-3b',\n",
+    "    base_model=\"bigscience/bloom-3b\",\n",
     "    initial_prompt=PROMPT,\n",
     "    keep_context=True,\n",
     "    creative=True,\n",
-    "    human_identifier='HUMANO',\n",
-    "    bot_identifier='EXPERTO'\n",
+    "    human_identifier=\"HUMANO\",\n",
+    "    bot_identifier=\"EXPERTO\",\n",
     ")\n",
     "\n",
     "while True:\n",
-    "    input_text = input('> ')\n",
-    "    if input_text == 'exit':\n",
+    "    input_text = input(\"> \")\n",
+    "    if input_text == \"exit\":\n",
     "        break\n",
     "    print(chatbot.chat(input_text))"
    ]
@@ -619,9 +616,9 @@
     "import torch\n",
     "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
     "\n",
-    "BASE_MODEL = 'DeepESP/gpt2-spanish' # We play with a smaller model\n",
+    "BASE_MODEL = \"DeepESP/gpt2-spanish\" # We play with a smaller model\n",
     "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)\n",
-    "model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)"
+    "model = AutoModelForCausalLM.from_pretrained(BASE_MODEL).to(device)"
    ]
   },
   {
@@ -663,7 +660,7 @@
    "source": [
     "torch.manual_seed(42) # To ensure determinism\n",
     "\n",
-    "input_ids = tokenizer.encode(\"Aquí me pongo a cantar\", return_tensors='pt')\n",
+    "input_ids = tokenizer.encode(\"Aquí me pongo a cantar\", return_tensors=\"pt\").to(device)\n",
     "sampling_output = model.generate(input_ids, do_sample=True, max_length=50, top_k=50, top_p=0.95)\n",
     "output = tokenizer.decode(sampling_output[0], skip_special_tokens=True)\n",
     "\n",
@@ -716,9 +713,14 @@
    "source": [
     "from datasets import load_dataset\n",
     "\n",
-    "datasets = load_dataset('text', data_files={'train': './data/martin-fierro_train.txt',\n",
-    "                                            'validation': './data/martin-fierro_validation.txt'})\n",
-    "print('\\n'.join(datasets['train'][:9]['text']))"
+    "datasets = load_dataset(\n",
+    "    \"text\",\n",
+    "    data_files={\n",
+    "        \"train\": \"./data/martin-fierro_train.txt\",\n",
+    "        \"validation\": \"./data/martin-fierro_validation.txt\",\n",
+    "    },\n",
+    ")\n",
+    "print(\"\\n\".join(datasets[\"train\"][:9][\"text\"]))"
    ]
   },
   {
@@ -750,7 +752,9 @@
    "source": [
     "from utils import tokenize # local module in the repository\n",
     "\n",
-    "tokenized_datasets = datasets.map(tokenize(tokenizer), batched=True, num_proc=4, remove_columns=['text'])"
+    "tokenized_datasets = datasets.map(\n",
+    "    tokenize(tokenizer), batched=True, num_proc=4, remove_columns=[\"text\"]\n",
+    ")"
    ]
   },
   {
@@ -831,8 +835,8 @@
     }
    ],
    "source": [
-    "print(len(lm_datasets['train'][0]['input_ids']))\n",
-    "print(lm_datasets['train'][0]['input_ids'][:10])"
+    "print(len(lm_datasets[\"train\"][0][\"input_ids\"]))\n",
+    "print(lm_datasets[\"train\"][0][\"input_ids\"][:10])"
    ]
   },
   {
@@ -876,7 +880,7 @@
     }
    ],
    "source": [
-    "print(tokenizer.decode(lm_datasets['train'][0]['input_ids']))"
+    "print(tokenizer.decode(lm_datasets[\"train\"][0][\"input_ids\"]))"
    ]
   },
   {
@@ -1022,24 +1026,24 @@
     "from transformers import Trainer, TrainingArguments\n",
     "\n",
     "training_args = TrainingArguments(\n",
-    "    'flisol-cba-martin-fierro',\n",
-    "    evaluation_strategy='epoch',\n",
+    "    \"flisol-cba-martin-fierro\",\n",
+    "    evaluation_strategy=\"epoch\",\n",
     "    num_train_epochs=10,\n",
     "    learning_rate=2e-5,\n",
     "    weight_decay=0.01,\n",
-    "    logging_steps=5\n",
+    "    logging_steps=5,\n",
     ")\n",
     "\n",
     "trainer = Trainer(\n",
     "    model=model,\n",
     "    args=training_args,\n",
-    "    train_dataset=lm_datasets['train'],\n",
-    "    eval_dataset=lm_datasets['validation']\n",
+    "    train_dataset=lm_datasets[\"train\"],\n",
+    "    eval_dataset=lm_datasets[\"validation\"],\n",
    ")\n",
     "\n",
     "trainer.train()\n",
     "trainer.push_to_hub() # This pushes the trained model to Hugging Face model repository\n",
-    "tokenizer.push_to_hub('flisol-cba-martin-fierro')"
+    "tokenizer.push_to_hub(\"flisol-cba-martin-fierro\")"
    ]
   },
   {
@@ -1088,13 +1092,13 @@
     "import torch\n",
     "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
     "\n",
-    "MODEL = 'flisol-cba-martin-fierro'\n",
+    "MODEL = \"flisol-cba-martin-fierro\"\n",
     "tokenizer = AutoTokenizer.from_pretrained(MODEL)\n",
-    "model = AutoModelForCausalLM.from_pretrained(MODEL)\n",
+    "model = AutoModelForCausalLM.from_pretrained(MODEL).to(device)\n",
     "\n",
     "torch.manual_seed(42) # To ensure determinism\n",
     "\n",
-    "input_ids = tokenizer.encode(\"Aquí me pongo a cantar\", return_tensors='pt')\n",
+    "input_ids = tokenizer.encode(\"Aquí me pongo a cantar\", return_tensors=\"pt\").to(device)\n",
     "sampling_output = model.generate(input_ids, do_sample=True, max_length=50, top_k=50, top_p=0.95)\n",
     "output = tokenizer.decode(sampling_output[0], skip_special_tokens=True)\n",
     "\n",
utils.py CHANGED
@@ -23,8 +23,9 @@ from transformers import PreTrainedTokenizerBase
 from typing import Callable, Dict, List
 
 
-def tokenize(tokenizer: PreTrainedTokenizerBase,
-             end_char: str = '\n') -> Callable[[Dict[str, List[str]]], DatasetDict]:
+def tokenize(
+    tokenizer: PreTrainedTokenizerBase, end_char: str = "\n"
+) -> Callable[[Dict[str, List[str]]], DatasetDict]:
     """
     Helper function that returns a function to use with the `map` method of
     datasets.DatasetDict. It takes a tokenizer and generates a function that
@@ -47,14 +48,14 @@ def tokenize(tokenizer: PreTrainedTokenizerBase,
         The function in charge of the tokenization process.
 
     """
+
     def _tokenize(examples: Dict[str, List[str]]) -> DatasetDict:
-        return tokenizer([f'{e}{end_char}' for e in examples['text']])
+        return tokenizer([f"{e}{end_char}" for e in examples["text"]])
 
     return _tokenize
 
 
-def group_texts(examples: Dict[str, List[int]],
-                block_size: int = 128) -> Dict[str, List[int]]:
+def group_texts(examples: Dict[str, List[int]], block_size: int = 128) -> Dict[str, List[int]]:
     """
     Helper function to concatenate a tokenized dataset (with the function above)
     in chunks of `block_size`. The code was taken from
@@ -80,13 +81,13 @@ def group_texts(examples: Dict[str, List[int]],
     """
     # Concatenate all texts.
     concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
-    total_length = len(concatenated_examples['input_ids'])
+    total_length = len(concatenated_examples["input_ids"])
     # We drop the small remainder, we could add padding if the model supported
     # it instead of this drop, you can customize this part to your needs
     total_length = (total_length // block_size) * block_size
     # Split by chunks of block_size length
     result = {
-        k: [t[i:i + block_size] for i in range(0, total_length, block_size)]
+        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
         for k, t in concatenated_examples.items()
     }
     # labels to be used by the training phase, it copies since the Transformers
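
For context on how these helpers fit together: tokenize() returns a closure meant for datasets.map(), and group_texts() concatenates the tokenized ids and re-splits them into fixed-size blocks with labels copied from input_ids. The lm_datasets used in the notebook is presumably built by chaining the two; a minimal sketch under that assumption (paths and model name follow the repository defaults):

from datasets import load_dataset
from transformers import AutoTokenizer
from utils import group_texts, tokenize

tokenizer = AutoTokenizer.from_pretrained("DeepESP/gpt2-spanish")

datasets = load_dataset(
    "text",
    data_files={
        "train": "./data/martin-fierro_train.txt",
        "validation": "./data/martin-fierro_validation.txt",
    },
)

# Tokenize every line (tokenize() appends end_char, "\n" by default).
tokenized_datasets = datasets.map(
    tokenize(tokenizer), batched=True, num_proc=4, remove_columns=["text"]
)

# Assumption: this is how the notebook builds lm_datasets (not shown in this diff).
# Concatenate and re-split into 128-token blocks; labels are a copy of input_ids.
lm_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=4)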