joaogante HF staff commited on
Commit
bd89ed8
β€’
1 Parent(s): 0b94c41

datasets refactor

Browse files
.gitignore ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Initially taken from Github's Python gitignore file
2
+
3
+ # Byte-compiled / optimized / DLL files
4
+ __pycache__/
5
+ *.py[cod]
6
+ *$py.class
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # tests and logs
12
+ tests/fixtures/cached_*_text.txt
13
+ logs/
14
+ lightning_logs/
15
+ lang_code_data/
16
+
17
+ # Distribution / packaging
18
+ .Python
19
+ build/
20
+ develop-eggs/
21
+ dist/
22
+ downloads/
23
+ eggs/
24
+ .eggs/
25
+ lib/
26
+ lib64/
27
+ parts/
28
+ sdist/
29
+ var/
30
+ wheels/
31
+ *.egg-info/
32
+ .installed.cfg
33
+ *.egg
34
+ MANIFEST
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+
46
+ # Unit test / coverage reports
47
+ htmlcov/
48
+ .tox/
49
+ .nox/
50
+ .coverage
51
+ .coverage.*
52
+ .cache
53
+ nosetests.xml
54
+ coverage.xml
55
+ *.cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ .python-version
90
+
91
+ # celery beat schedule file
92
+ celerybeat-schedule
93
+
94
+ # SageMath parsed files
95
+ *.sage.py
96
+
97
+ # Environments
98
+ .env
99
+ .venv
100
+ env/
101
+ venv/
102
+ ENV/
103
+ env.bak/
104
+ venv.bak/
105
+
106
+ # Spyder project settings
107
+ .spyderproject
108
+ .spyproject
109
+
110
+ # Rope project settings
111
+ .ropeproject
112
+
113
+ # mkdocs documentation
114
+ /site
115
+
116
+ # mypy
117
+ .mypy_cache/
118
+ .dmypy.json
119
+ dmypy.json
120
+
121
+ # Pyre type checker
122
+ .pyre/
123
+
124
+ # vscode
125
+ .vs
126
+ .vscode
127
+
128
+ # Pycharm
129
+ .idea
130
+
131
+ # TF code
132
+ tensorflow_code
133
+
134
+ # Models
135
+ proc_data
136
+
137
+ # examples
138
+ runs
139
+ /runs_old
140
+ /wandb
141
+ /examples/runs
142
+ /examples/**/*.args
143
+ /examples/rag/sweep
144
+
145
+ # data
146
+ /data
147
+ serialization_dir
148
+
149
+ # emacs
150
+ *.*~
151
+ debug.env
152
+
153
+ # vim
154
+ .*.swp
155
+
156
+ #ctags
157
+ tags
158
+
159
+ # pre-commit
160
+ .pre-commit*
161
+
162
+ # .lock
163
+ *.lock
164
+
165
+ # DS_Store (MacOS)
166
+ .DS_Store
167
+
168
+ # ruff
169
+ .ruff_cache
app.py CHANGED
@@ -1,22 +1,20 @@
1
- from git import Repo
2
- import gradio as gr
 
3
 
4
- from medusa_training import run, DEFAULT_TRAINING_ARGS
5
 
6
- # Clone the medusa repo locally
7
- print("Cloning the medusa repo locally...")
8
- Repo.clone_from("https://github.com/FasterDecoding/Medusa.git", "medusa")
9
- print("Cloning the vicuna data locally...")
10
- Repo.clone_from("https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered", "data")
11
- print("Done")
12
 
13
 
14
  DESCRIPTION = """
15
  The steps to create [medusa](https://sites.google.com/view/medusa-llm) heads are the following:
16
 
17
  1. Input a public model id from the Hub
18
- 2. Click "Submit"
19
- 3. That's it! You'll get feedback if it works or not, and if it worked, you'll get the name of the new repo πŸ”₯
 
20
  """
21
 
22
  title="Create LLM medusa heads in a new repo 🐍"
@@ -28,8 +26,12 @@ with gr.Blocks(title=title) as demo:
28
  with gr.Row() as r:
29
  with gr.Column() as c:
30
  model_id = gr.Text(max_lines=1, label="model_id")
 
 
 
 
31
  with gr.Accordion("Training arguments (advanced)", open=False):
32
- training_args = gr.Textbox(DEFAULT_TRAINING_ARGS, interactive=True, lines=14, label="training_args")
33
  with gr.Row() as c:
34
  clean = gr.ClearButton()
35
  submit = gr.Button("Submit", variant="primary")
@@ -37,6 +39,6 @@ with gr.Blocks(title=title) as demo:
37
  with gr.Column() as d:
38
  status_box = gr.Markdown()
39
 
40
- submit.click(run, inputs=[model_id, training_args], outputs=status_box, concurrency_limit=1)
41
 
42
  demo.queue(max_size=10).launch(show_api=True)
 
1
+ """
2
+ Holds the gradio app itself
3
+ """
4
 
5
+ import gradio as gr
6
 
7
+ from src.train_workflow import run, DEFAULT_TRAINING_ARGS
8
+ from src.calibration_datasets import CalibrationDataset
 
 
 
 
9
 
10
 
11
  DESCRIPTION = """
12
  The steps to create [medusa](https://sites.google.com/view/medusa-llm) heads are the following:
13
 
14
  1. Input a public model id from the Hub
15
+ 2. Select a dataset to train the medusa heads on. The dataset should be representative of the downstream use case.
16
+ 3. Click "Submit"
17
+ 4. That's it! You'll get feedback if it works or not, and if it worked, you'll get the name of the new repo πŸ”₯
18
  """
19
 
20
  title="Create LLM medusa heads in a new repo 🐍"
 
26
  with gr.Row() as r:
27
  with gr.Column() as c:
28
  model_id = gr.Text(max_lines=1, label="model_id")
29
+ dataset_names = [
30
+ cls.dataset for cls in CalibrationDataset.__subclasses__()
31
+ ]
32
+ dataset = gr.Dropdown(dataset_names, label="dataset")
33
  with gr.Accordion("Training arguments (advanced)", open=False):
34
+ training_args = gr.Textbox(DEFAULT_TRAINING_ARGS, interactive=True, lines=20, label="training_args")
35
  with gr.Row() as c:
36
  clean = gr.ClearButton()
37
  submit = gr.Button("Submit", variant="primary")
 
39
  with gr.Column() as d:
40
  status_box = gr.Markdown()
41
 
42
+ submit.click(run, inputs=[model_id, training_args, dataset], outputs=status_box, concurrency_limit=1)
43
 
44
  demo.queue(max_size=10).launch(show_api=True)
medusa_heads_medusa_TinyLlama-1.1B-Chat-v1.0/config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "base_model_name_or_path": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
3
+ "medusa_num_heads": 3,
4
+ "medusa_num_layers": 1,
5
+ "transformers_version": "4.37.0.dev0"
6
+ }
requirements.txt CHANGED
@@ -1,2 +1 @@
1
  medusa-llm[train]
2
- gitpython
 
1
  medusa-llm[train]
 
src/calibration_datasets.py ADDED
@@ -0,0 +1,603 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Prepares the datasets for calibration. Original code gently shared by TheBloke"""
2
+
3
+ from abc import ABC
4
+ import time
5
+ from typing import Dict, List, Optional
6
+ from datasets import load_dataset, Dataset
7
+ from transformers import PreTrainedTokenizerBase
8
+
9
+
10
+ class CalibrationDataset(ABC):
11
+ tokenizer: Optional[PreTrainedTokenizerBase] = None
12
+ num_samples: int = 128
13
+ seqlen: int = 4096
14
+ dataset_config: dict
15
+ dataset: str
16
+ dataset_name: str
17
+ dataset_limit: int = int(1e7)
18
+
19
+ # Defines the field to extract from the HF dataset
20
+ # If specified, just this field will be returned, and no transformation will be done.
21
+ dataset_field: Optional[str] = None
22
+
23
+ # Define the default parameters for a dataset which requires a transformation
24
+ # Only used if dataset_field is None.
25
+ # The fields to extract from the original dataset
26
+ transform_fields: List[str] = []
27
+
28
+ # A format string describing how the fields should be joined
29
+ # Can use {field1}, {field2}, etc. as placeholders for the field names
30
+ # Or can use actual names, eg "{input} {output}"
31
+ transform_join: str = "{field1} {field2}"
32
+
33
+ # Optional override for the dataset URL
34
+ # By default this is automatically derived from the dataset name and config
35
+ dataset_url: Optional[str] = None
36
+
37
+ data: Optional[Dataset] = None
38
+ samples: List[str] = []
39
+ tokenized_samples: List[Dict[str, str]] = {}
40
+
41
+ randomize: bool = False
42
+ randomize_seed: int = 42
43
+
44
+ def __init__(
45
+ self,
46
+ num_samples: int = 128,
47
+ seqlen: int = 4096,
48
+ tokenizer: Optional[PreTrainedTokenizerBase] = None
49
+ ):
50
+ self.num_samples = num_samples
51
+ self.seqlen = seqlen
52
+ self.tokenizer = tokenizer
53
+
54
+ @classmethod
55
+ def get_dataset(cls, dataset_name, **kwargs):
56
+ for subclass in cls.__subclasses__():
57
+ if hasattr(subclass, "dataset") and subclass.dataset == dataset_name:
58
+ return subclass(**kwargs)
59
+
60
+ raise ValueError(f"No dataset class found for name: {dataset_name}")
61
+
62
+ def tokenize_dataset(self, samples: Optional[List[str]] = None) -> List[Dict[str, int]]:
63
+ """
64
+ Tokenize the dataset and return a list of tokens of `seqlen` length
65
+
66
+ First tokenize the List[str] of samples, as a batch.
67
+
68
+ Then flatten the batch, and split it into `num_samples` rows of `seqlen` length.
69
+ """
70
+ if not self.tokenizer:
71
+ raise ValueError("No tokenizer provided to tokenize_dataset()")
72
+ else:
73
+ if not samples:
74
+ if not self.samples:
75
+ self.get_samples()
76
+ samples = self.samples
77
+
78
+ print(f"Tokenizing {self.dataset_name} of length {len(samples)}")
79
+
80
+ start_time = time.time()
81
+ # Tokenize the list of samples. We don't use return_tensors="pt",
82
+ # as that requires the samples to be the same length, or padding to be used.
83
+ tokenized = self.tokenizer(samples)
84
+
85
+ # Output of tokenizer will be:
86
+ # {"input_ids": [[1,2,3], [4,5], [6,7]], "attention_mask": [[1,1,1], [1,1], [1,1]]}
87
+ # Flatten that so as to concatenate the samples into a single input_mask and attention_mask
88
+ flattened = {
89
+ key: [
90
+ item for sublist in value
91
+ for item in sublist
92
+ ]
93
+ for key, value in tokenized.items()
94
+ }
95
+ print(
96
+ f"Tokenized length: {len(flattened['input_ids'])} tokens."
97
+ )
98
+
99
+ # Slice our single input_mask list into num_samples samples of seqlen length
100
+ tokenized_samples = []
101
+ for i in range(0, self.num_samples * self.seqlen, self.seqlen):
102
+ if i + self.seqlen >= len(flattened["input_ids"]):
103
+ break
104
+ sample = {
105
+ "input_ids": flattened["input_ids"][i:i + self.seqlen],
106
+ "attention_mask": flattened["attention_mask"][i:i + self.seqlen]
107
+ }
108
+ tokenized_samples.append(sample)
109
+
110
+ print(
111
+ f"Return {len(tokenized_samples)} samples of {self.seqlen} length. "
112
+ f"Time taken: {time.time() - start_time:.2f}s."
113
+ )
114
+ self.tokenized_samples = tokenized_samples
115
+ return self.tokenized_samples
116
+
117
+ def get_hf_dataset(
118
+ self,
119
+ path: str,
120
+ limit: Optional[int] = None,
121
+ **kwargs
122
+ ) -> Dataset:
123
+ """Load the Hugging Face dataset at `path`, using the provided kwargs."""
124
+
125
+ print(f"Loading HF dataset {path} with params: {kwargs}")
126
+ data: Dataset = load_dataset(path=path, **kwargs)
127
+
128
+ limit = limit and min(limit, len(data)) or len(data)
129
+ return data.select(range(limit))
130
+
131
+ @staticmethod
132
+ def list_with_nls(samples: List[str]) -> List[str]:
133
+ """
134
+ Return a List[str] with each sample ending in a newline.
135
+
136
+ Also filters the list by stripping, then removing any empty samples.
137
+ """
138
+ return [
139
+ x.rstrip() + '\n'
140
+ for x in samples
141
+ if x and len(x.strip()) > 0
142
+ ]
143
+
144
+ def get_samples(self) -> List[str]:
145
+ """
146
+ Return a list of samples for the dataset.
147
+
148
+ If the subclass implements `dataset_field`, this is used to filter the HF Dataset.
149
+
150
+ Otherwise, the subclass must implement `process_samples()`, for custom filtering.
151
+
152
+ Samples are returned as a List[str], each ending in a newline.
153
+ """
154
+ # Load HF dataset. Subclasses provide HF dataset details in `dataset_config`
155
+ if not self.data:
156
+ self.data = self.get_hf_dataset(**self.dataset_config, limit=self.dataset_limit)
157
+
158
+ if not self.samples:
159
+ if hasattr(self, "dataset_field") and self.dataset_field:
160
+ samples = self.data[self.dataset_field]
161
+ else:
162
+ try:
163
+ samples = self.process_samples()
164
+ except NotImplementedError:
165
+ raise ValueError(
166
+ f"No dataset field specified for class {self.__class__}, "
167
+ f"and process_samples() method not defined."
168
+ )
169
+ if self.randomize:
170
+ import random
171
+ random.seed(self.randomize_seed)
172
+ random.shuffle(samples)
173
+ self.samples = self.list_with_nls(samples)
174
+ return self.samples
175
+
176
+ def process_samples(self) -> List[str]:
177
+ if not self.transform_fields or not isinstance(self.transform_fields, list):
178
+ raise ValueError("transform_fields must be a List[str], defined in the subclass")
179
+
180
+ if not self.transform_join or not isinstance(self.transform_join, str):
181
+ raise ValueError("transform_fields must be a str defined in the subclass")
182
+
183
+ def transform_sample(sample):
184
+ field_values = {field: sample[field] for field in self.transform_fields}
185
+ # We support both:
186
+ # generic numbered fields: "{field1} {field2}"
187
+ # and named fields: "{input} {output}"
188
+ # Creating a combined dictionary to handle both specific field names and generic placeholders
189
+ combined_dict = {**field_values, **{f'field{i+1}': field for i, field in enumerate(field_values.values())}}
190
+ output = self.transform_join.format_map(combined_dict)
191
+ return {"output": output}
192
+
193
+ return self.data.map(transform_sample)["output"]
194
+
195
+ def generate_checksum(self) -> str:
196
+ # Create a sha256sum checksum of the joined samples
197
+ # Can be used to confirm that code updates haven't changed the output
198
+ import hashlib
199
+ samples = self.get_samples()
200
+ combined_samples = ''.join(samples)
201
+ checksum = hashlib.sha256(combined_samples.encode()).hexdigest()
202
+ return checksum
203
+
204
+ @classmethod
205
+ def get_dataset_url(cls) -> str:
206
+ """Return the Hugging Face dataset URL for this dataset."""
207
+ if hasattr(cls, "dataset_url") and cls.dataset_url:
208
+ return cls.dataset_url
209
+ else:
210
+ return "https://huggingface.co/datasets/{}/viewer/{}".format(
211
+ cls.dataset_config["path"],
212
+ cls.dataset_config.get("name", "")
213
+ )
214
+
215
+
216
+ class WikitextDataset(CalibrationDataset):
217
+ dataset = "wikitext"
218
+ dataset_config = {
219
+ "path": "wikitext",
220
+ "name": "wikitext-2-raw-v1",
221
+ "split": "train"
222
+ }
223
+ dataset_name = "Wikitext2 Full"
224
+
225
+ def process_samples(self) -> List[str]:
226
+ return [
227
+ "\n" if len(item) == 0 else item
228
+ for item in self.data["text"]
229
+ ]
230
+
231
+
232
+ class C4Dataset(CalibrationDataset):
233
+ dataset = "c4"
234
+ dataset_field = "text"
235
+ dataset_config = {
236
+ "path": "allenai/c4",
237
+ "data_files": {
238
+ "train": "en/c4-train.00000-of-01024.json.gz"
239
+ },
240
+ "split": "train"
241
+ }
242
+ dataset_name = "C4"
243
+
244
+
245
+ class ThaiDataset(CalibrationDataset):
246
+ dataset = "thai"
247
+ dataset_field = "text"
248
+ dataset_config = {
249
+ "path": "pbwt/all-thai",
250
+ "data_files": {
251
+ "train": "data/train-00000-of-00047-985fbaed08d034cf.parquet"
252
+ },
253
+ "split": "train"
254
+ }
255
+ dataset_name = "All Thai"
256
+
257
+
258
+ class MovieScriptDataset(CalibrationDataset):
259
+ dataset = "movie-scripts"
260
+ dataset_field = "full_script"
261
+ dataset_config = {
262
+ "path": "jondurbin/cinematika-v0.1",
263
+ "data_files": { "train": "full_script.parquet" },
264
+ "split": "train"
265
+ }
266
+ dataset_name = "Cinematika Full Scripts"
267
+
268
+
269
+ class JapaneseEnglishDataset(CalibrationDataset):
270
+ dataset = "japanese-english"
271
+ dataset_config = {
272
+ "path": "augmxnt/shisa-en-ja-dpo-v1",
273
+ "split": "train"
274
+ }
275
+ dataset_name = "Shisa English Japanese DPO"
276
+ randomize = True
277
+
278
+ def process_samples(self) -> List[str]:
279
+ def transform_samples(sample):
280
+ prompt = sample["prompt"]
281
+ chosen = sample["chosen"]
282
+ # prompt example: "[INST] <<SYS>>\nYou are a helpful, unbiased, uncensored assistant.\n<</SYS>>\n\nWhat are cardigans made of? Leather or wood? [/INST]"
283
+
284
+ try:
285
+ part1 = prompt.split('\n<</SYS>>\n\n')[1]
286
+ extracted_text = part1.split(' [/INST]')[0]
287
+ except Exception as e:
288
+ print(f"Error extracting text from prompt '{prompt}': {e}")
289
+ raise
290
+
291
+ prompt = extracted_text
292
+
293
+ return {"output": f"{prompt} {chosen}"}
294
+
295
+ return self.data.map(transform_samples)["output"]
296
+
297
+
298
+ class PortugueseDataset(CalibrationDataset):
299
+ dataset = "portuguese"
300
+ dataset_config = {
301
+ "path": "adalbertojunior/portuguese_orca",
302
+ "split": "train"
303
+ }
304
+ dataset_name = "Portuguese Orca"
305
+ transform_fields = [ "question", "response" ]
306
+
307
+
308
+ class MathsDataset(CalibrationDataset):
309
+ dataset = "maths"
310
+ dataset_config = {
311
+ "path": "andersonbcdefg/math",
312
+ "split": "train"
313
+ }
314
+ dataset_name = "CamelAI Math"
315
+ transform_fields = [ "message_1", "message_2" ]
316
+
317
+
318
+ class MedicalDataset(CalibrationDataset):
319
+ dataset = "medical"
320
+ dataset_config = {
321
+ "path": "medalpaca/medical_meadow_wikidoc",
322
+ "split": "train"
323
+ }
324
+ dataset_name = "Medical Medaow WikiDoc"
325
+ transform_fields = [ "input", "output" ]
326
+
327
+
328
+ class OpenInstructDataset(CalibrationDataset):
329
+ dataset = "open-instruct"
330
+ dataset_config = {
331
+ "path": "VMware/open-instruct",
332
+ "split": "train"
333
+ }
334
+ dataset_name = "VMware Open Instruct"
335
+ transform_fields = [ "instruction", "response" ]
336
+
337
+
338
+ class KoreanDataset(CalibrationDataset):
339
+ dataset = "korean"
340
+ dataset_config = {
341
+ "path": "beomi/KoAlpaca-v1.1a",
342
+ "split": "train"
343
+ }
344
+ dataset_name = "Korean Alpaca"
345
+ transform_fields = [ "instruction", "output" ]
346
+
347
+
348
+ class CodeDataset(CalibrationDataset):
349
+ dataset = "code"
350
+ dataset_field = "output"
351
+ dataset_config = {
352
+ "path": "nickrosh/Evol-Instruct-Code-80k-v1",
353
+ "split": "train"
354
+ }
355
+ dataset_name = "Evol Instruct Code"
356
+
357
+
358
+ class MultiLanguageDataset(CalibrationDataset):
359
+ dataset = "multi-language"
360
+ dataset_field = "text"
361
+ dataset_config = {
362
+ "path": "papluca/language-identification",
363
+ "split": "train"
364
+ }
365
+ dataset_name = "Language Identification"
366
+
367
+
368
+ class RussianDataset(CalibrationDataset):
369
+ dataset = "russian"
370
+ dataset_config = {
371
+ "path": "Den4ikAI/russian_instructions_2",
372
+ "split": "train"
373
+ }
374
+ dataset_name = "Russian Instructions 2"
375
+ transform_fields = [ "question", "answer" ]
376
+
377
+
378
+ class DutchDataset(CalibrationDataset):
379
+ dataset = "dutch"
380
+ dataset_config = {
381
+ "path": "BramVanroy/dolly-15k-dutch",
382
+ "split": "train"
383
+ }
384
+ dataset_name = "Dolly 15K Dutch"
385
+ transform_fields = [ "instruction", "context", "response" ]
386
+ transform_join = "{field1} {field2} {field3}"
387
+
388
+
389
+ class VietnameseChineseDataset(CalibrationDataset):
390
+ dataset = "vietnamesechinese"
391
+ dataset_config = {
392
+ "path": "nRuaif/Vietnamese_x_Alpaca",
393
+ "split": "train"
394
+ }
395
+ dataset_name = "Vietnamese and Chinese"
396
+
397
+ def get_dataset_url(self) -> None:
398
+ return None
399
+
400
+ def process_samples(self) -> List[str]:
401
+ samples = self.data["output"]
402
+ chinese_samples = CalibrationDataset.get_dataset("chinese").get_samples()
403
+
404
+ joined_list = samples + chinese_samples
405
+
406
+ import random
407
+ random.shuffle(joined_list)
408
+
409
+ return joined_list[:self.dataset_limit]
410
+
411
+
412
+ class VietnameseDataset(CalibrationDataset):
413
+ dataset = "vietnamese"
414
+ dataset_field = "output"
415
+ dataset_config = {
416
+ "path": "nRuaif/Vietnamese_x_Alpaca",
417
+ "split": "train"
418
+ }
419
+ dataset_name = "Alpaca Vietnamese"
420
+
421
+
422
+ class ChineseDataset(CalibrationDataset):
423
+ dataset = "chinese"
424
+ dataset_config = {
425
+ "path": "TigerResearch/tigerbot-alpaca-zh-0.5m",
426
+ "split": "train"
427
+ }
428
+ dataset_name = "Tiger Alpaca ZH"
429
+ transform_fields = [ "instruction", "input", "output" ]
430
+ transform_join = "{field1} {field2} {field3}"
431
+
432
+
433
+ class LatinEnglishDataset(CalibrationDataset):
434
+ dataset = "latin-english"
435
+ dataset_config = {
436
+ "path": "grosenthal/latin_english_parallel",
437
+ "split": "train"
438
+ }
439
+ dataset_name = "Latin English Parallel"
440
+ transform_fields = [ "la", "en" ]
441
+ transform_join = "{field1}\n{field2}"
442
+
443
+
444
+ class PolishDataset(CalibrationDataset):
445
+ dataset = "polish"
446
+ dataset_field = "content"
447
+ dataset_config = {
448
+ "path": "WiktorS/polish-news",
449
+ "split": "train"
450
+ }
451
+ dataset_name = "Polish News"
452
+
453
+
454
+ class JapaneseDataset(CalibrationDataset):
455
+ dataset = "japanese"
456
+ dataset_field = "output"
457
+ dataset_config = {
458
+ "path": "fujiki/japanese_alpaca_data",
459
+ "split": "train"
460
+ }
461
+ dataset_name = "Alpaca Japanese"
462
+
463
+
464
+ class SpanishDataset(CalibrationDataset):
465
+ dataset = "spanish"
466
+ dataset_field = "output"
467
+ dataset_config = {
468
+ "path": "bertin-project/alpaca-spanish",
469
+ "split": "train"
470
+ }
471
+ dataset_name = "Alpaca Spanish"
472
+
473
+
474
+ class GermanDataset(CalibrationDataset):
475
+ dataset = "german"
476
+ dataset_config = {
477
+ "path": "deepset/germanquad",
478
+ "split": "train"
479
+ }
480
+ dataset_name = "German Quad"
481
+
482
+ def process_samples(self) -> List[str]:
483
+ def transform_samples(sample):
484
+ split_context = sample["context"].split("===")
485
+ if len(split_context) >= 3:
486
+ trans_context = split_context[2]
487
+ else:
488
+ trans_context = sample["context"]
489
+ return {"output": trans_context.strip()}
490
+
491
+ return self.data.map(transform_samples)["output"]
492
+
493
+
494
+ class FrenchDataset(CalibrationDataset):
495
+ dataset = "french"
496
+ dataset_field = "text"
497
+ dataset_config = {
498
+ "path": "Kant1/French_Wikipedia_articles",
499
+ "data_files": { "wiki_00.txt" },
500
+ "split": "train"
501
+ }
502
+ dataset_name = "French Wikipedia Articles"
503
+
504
+
505
+ def validate_dataset(dataset_name: str, **kwargs):
506
+ for cls in CalibrationDataset.__subclasses__():
507
+ if hasattr(cls, "dataset") and cls.dataset == dataset_name:
508
+ return True
509
+ return False
510
+
511
+ # FIXME: a temp function put in for AutoAWQ, pending full refactor where it won't be necessary
512
+ def get_dataset_url(dataset_name: str):
513
+ for cls in CalibrationDataset.__subclasses__():
514
+ if hasattr(cls, "dataset") and cls.dataset == dataset_name:
515
+ return cls.get_dataset_url()
516
+ raise ValueError(f"No dataset class found for name: {dataset_name}")
517
+
518
+ def get_dataset_name(dataset_name: str):
519
+ for cls in CalibrationDataset.__subclasses__():
520
+ if hasattr(cls, "dataset") and cls.dataset == dataset_name:
521
+ return cls.dataset_name
522
+ raise ValueError(f"No dataset class found for name: {dataset_name}")
523
+
524
+ def test_datasets(datasets: Optional[List[str]] = None, checksum_only=False):
525
+ import sys
526
+ from transformers import AutoTokenizer
527
+ try:
528
+ failed = []
529
+ for cls in CalibrationDataset.__subclasses__():
530
+ if not hasattr(cls, "dataset") or not cls.dataset:
531
+ failed.append(cls.__name__)
532
+ if failed:
533
+ print(f"The following classes have no 'dataset' attribute: {failed}")
534
+ sys.exit(-1)
535
+ else:
536
+ print()(f"All classes have 'dataset' attribute.")
537
+
538
+ print(f"Enumerating CalibrationDataset classes")
539
+ classes = CalibrationDataset.__subclasses__()
540
+ dataset_names = [
541
+ cls.dataset
542
+ for cls in classes
543
+ if cls.dataset and (not datasets or cls.dataset in datasets)
544
+ ]
545
+
546
+ print(f"Found {len(classes)} total dataset classes: {[c.dataset for c in classes]}")
547
+ if datasets:
548
+ print(f"Will test {len(dataset_names)} datasets: {dataset_names}")
549
+
550
+ print(f"Starting test: loading Llama-2 tokenizer")
551
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", use_fast=True)
552
+
553
+ for name in dataset_names:
554
+ print(f"{name} test: loading dataset.")
555
+ dataset = CalibrationDataset.get_dataset(name, tokenizer=tokenizer)
556
+ if not checksum_only:
557
+ print(f"{name} test: running tokenize_dataset.")
558
+ toks = dataset.tokenize_dataset()
559
+ print(f"{name} test: getting dataset_url.")
560
+ url = dataset.get_dataset_url()
561
+ print(f"{name} - randomized? {dataset.randomize}")
562
+ print(
563
+ f"{name} - result: cls.data: length: {len(dataset.data)}, "
564
+ f"first row length: {len(dataset.data[0])}, "
565
+ f"first row data: '{dataset.data[0]}'."
566
+ )
567
+ print(
568
+ f"{name} - result: cls.samples: length: {len(dataset.samples)}, "
569
+ f"first row length: {len(dataset.samples[0])}, "
570
+ f"first row sample: '{dataset.samples[0]}'."
571
+ )
572
+ print(
573
+ f"{name} - result: tokenize_dataset result: length: {len(toks)}, "
574
+ f"length first row input_ids: {len(toks[0]['input_ids'])}."
575
+ )
576
+ print(
577
+ f"{name} - result: dataset_url: {url}"
578
+ )
579
+ checksum = dataset.generate_checksum()
580
+ print(
581
+ f"{name} - result: sha256 checksum: {checksum}"
582
+ )
583
+
584
+ except KeyboardInterrupt:
585
+ print("Test aborted")
586
+
587
+ except Exception as e:
588
+ print(
589
+ f"Received an exception during test. Test failed. "
590
+ f"Exception: {e}"
591
+ )
592
+ raise
593
+
594
+
595
+ if __name__ == "__main__":
596
+ import argparse
597
+
598
+ parser = argparse.ArgumentParser(description="Test calibration datasets")
599
+ parser.add_argument("--datasets", "-d", "-n", nargs="*", type=str, help="Dataset(s) to check; default is all")
600
+ parser.add_argument("--checksum_only", "-co", action="store_true", help="Only ouput the checksums for the datasets")
601
+ args = parser.parse_args()
602
+
603
+ test_datasets(args.datasets, checksum_only=args.checksum_only)
src/medusa_training_script.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hold the training script for the medusa model.
3
+
4
+ Adapted from the original code here: https://github.com/FasterDecoding/Medusa/blob/main/medusa/train/train.py
5
+ """
6
+
7
+ import os
8
+ from dataclasses import dataclass, field
9
+ import pathlib
10
+ from typing import Dict, Optional
11
+
12
+ import torch
13
+ from torch.utils.data import Dataset
14
+ import transformers
15
+ from transformers import Trainer, BitsAndBytesConfig
16
+ from transformers.trainer_pt_utils import LabelSmoother
17
+ from torch.nn import CrossEntropyLoss
18
+ from medusa.model.medusa_model import MedusaModel, MedusaConfig
19
+
20
+ from calibration_datasets import CalibrationDataset
21
+
22
+
23
+ IGNORE_TOKEN_ID = LabelSmoother.ignore_index
24
+
25
+
26
+ # Customized for training Medusa heads
27
+ class CustomizedTrainer(Trainer):
28
+ def compute_loss(self, model, inputs, return_outputs=False):
29
+ """
30
+ Compute the training loss for the model.
31
+
32
+ Args:
33
+ model (torch.nn.Module): The model for which to compute the loss.
34
+ inputs (dict): The input data, including input IDs, attention mask, and labels.
35
+ return_outputs (bool): Whether to return model outputs along with the loss.
36
+
37
+ Returns:
38
+ Union[float, Tuple[float, torch.Tensor]]: The computed loss, optionally with model outputs.
39
+ """
40
+ # DDP will give us model.module
41
+ if hasattr(model, "module"):
42
+ medusa = model.module.medusa
43
+ else:
44
+ medusa = model.medusa
45
+
46
+ logits = model(
47
+ input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
48
+ )
49
+ labels = inputs["labels"]
50
+ # Shift so that tokens < n predict n
51
+ loss = 0
52
+ loss_fct = CrossEntropyLoss()
53
+ log = {}
54
+ for i in range(medusa):
55
+ medusa_logits = logits[i, :, : -(2 + i)].contiguous()
56
+ medusa_labels = labels[..., 2 + i :].contiguous()
57
+ medusa_logits = medusa_logits.view(-1, logits.shape[-1])
58
+ medusa_labels = medusa_labels.view(-1)
59
+ medusa_labels = medusa_labels.to(medusa_logits.device)
60
+ loss_i = loss_fct(medusa_logits, medusa_labels)
61
+ loss += loss_i
62
+ not_ignore = medusa_labels.ne(IGNORE_TOKEN_ID)
63
+ medusa_labels = medusa_labels[not_ignore]
64
+
65
+ # Add top-k accuracy
66
+ for k in range(1, 6):
67
+ _, topk = medusa_logits.topk(k, dim=-1)
68
+ topk = topk[not_ignore]
69
+ correct = topk.eq(medusa_labels.unsqueeze(-1)).any(-1)
70
+ log[f"medusa{i}_top{k}"] = correct.float().mean().item()
71
+
72
+ log[f"medusa{i}_loss"] = loss_i.item()
73
+ self.log(log)
74
+ return (loss, logits) if return_outputs else loss
75
+
76
+
77
+ @dataclass
78
+ class ModelArguments:
79
+ model_name_or_path: Optional[str] = field()
80
+ load_in_4bit: bool = field(
81
+ default=False,
82
+ metadata={"help": "Load in 4 bit."},
83
+ )
84
+ load_in_8bit: bool = field(
85
+ default=False,
86
+ metadata={"help": "Load in 8 bit."},
87
+ )
88
+
89
+
90
+ @dataclass
91
+ class DataArguments:
92
+ dataset: str = field(
93
+ metadata={"help": "One of the datasets names in a CalibrationDataset subclass."},
94
+ )
95
+
96
+
97
+ @dataclass
98
+ class TrainingArguments(transformers.TrainingArguments):
99
+ cache_dir: Optional[str] = field(default=None)
100
+ optim: str = field(default="adamw_torch")
101
+ model_max_length: int = field(
102
+ default=2048,
103
+ metadata={
104
+ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
105
+ },
106
+ )
107
+ medusa_num_heads: int = field(
108
+ default=1,
109
+ metadata={"help": "Number of Medusa heads."},
110
+ )
111
+ medusa_num_layers: int = field(
112
+ default=1,
113
+ metadata={"help": "Number of layers for each Medusa head."},
114
+ )
115
+
116
+
117
+ local_rank = None
118
+
119
+
120
+ def rank0_print(*args):
121
+ if local_rank == 0:
122
+ print(*args)
123
+
124
+
125
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
126
+ """
127
+ Save the model's state dictionary to a specified directory.
128
+
129
+ Args:
130
+ trainer (transformers.Trainer): The Hugging Face Trainer object.
131
+ output_dir (str): The directory where the model state dictionary will be saved.
132
+ """
133
+ state_dict = trainer.model.state_dict()
134
+ if trainer.args.should_save:
135
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
136
+ del state_dict
137
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
138
+
139
+
140
+ class SupervisedDataset(Dataset):
141
+ """Dataset for supervised fine-tuning.
142
+
143
+ Args:
144
+ dataset (str): One of the datasets names in a CalibrationDataset subclass.
145
+ tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing.
146
+ """
147
+
148
+ def __init__(self, dataset, tokenizer: transformers.PreTrainedTokenizer):
149
+ super(SupervisedDataset, self).__init__()
150
+
151
+ rank0_print("Formatting inputs...")
152
+ dataset_classes = CalibrationDataset.__subclasses__()
153
+ for dataset_class in dataset_classes:
154
+ if dataset_class.dataset == dataset:
155
+ dataset = dataset_class(num_samples=int(1e6), seqlen=tokenizer.model_max_length, tokenizer=tokenizer)
156
+ break
157
+ tokenized = dataset.tokenize_dataset()
158
+ self.input_ids = torch.tensor([data["input_ids"] for data in tokenized], dtype=torch.long)
159
+ self.attention_mask = torch.tensor([data["attention_mask"] for data in tokenized], dtype=torch.long)
160
+
161
+ def __len__(self):
162
+ return self.input_ids.shape[0]
163
+
164
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
165
+ return dict(
166
+ input_ids=self.input_ids[i],
167
+ labels=self.input_ids[i],
168
+ attention_mask=self.attention_mask[i],
169
+ )
170
+
171
+
172
+ def train():
173
+ global local_rank
174
+
175
+ parser = transformers.HfArgumentParser(
176
+ (ModelArguments, DataArguments, TrainingArguments)
177
+ )
178
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
179
+ local_rank = training_args.local_rank
180
+
181
+ config = transformers.AutoConfig.from_pretrained(
182
+ model_args.model_name_or_path,
183
+ cache_dir=training_args.cache_dir,
184
+ )
185
+ config.use_cache = False
186
+
187
+ quantization_config = BitsAndBytesConfig(
188
+ load_in_4bit=True,
189
+ bnb_4bit_compute_dtype=torch.bfloat16,
190
+ bnb_4bit_use_double_quant=True,
191
+ bnb_4bit_quant_type="nf4",
192
+ )
193
+
194
+ # Load model and tokenizer
195
+ model = transformers.AutoModelForCausalLM.from_pretrained(
196
+ model_args.model_name_or_path,
197
+ config=config,
198
+ cache_dir=training_args.cache_dir,
199
+ low_cpu_mem_usage=True,
200
+ torch_dtype=torch.bfloat16,
201
+ quantization_config=quantization_config if model_args.load_in_4bit else None,
202
+ load_in_4bit=model_args.load_in_4bit,
203
+ load_in_8bit=model_args.load_in_8bit,
204
+ )
205
+
206
+ # Freeze the base model
207
+ for param in model.base_model.parameters():
208
+ param.requires_grad = False
209
+
210
+ # Add Medusa heads
211
+ medusa_lm_head = MedusaModel(
212
+ model,
213
+ medusa_num_heads=training_args.medusa_num_heads,
214
+ medusa_num_layers=training_args.medusa_num_layers,
215
+ base_model_name_or_path=model_args.model_name_or_path,
216
+ )
217
+
218
+ # Format output dir
219
+ training_args.output_dir = f"{training_args.output_dir}_medusa_{model_args.model_name_or_path.split('/')[-1]}"
220
+
221
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
222
+ model_args.model_name_or_path,
223
+ cache_dir=training_args.cache_dir,
224
+ model_max_length=training_args.model_max_length,
225
+ padding_side="right",
226
+ use_fast=False,
227
+ )
228
+ tokenizer.pad_token = tokenizer.unk_token
229
+
230
+ # Load data
231
+ data_module = {"train_dataset": SupervisedDataset(data_args.dataset, tokenizer), "eval_dataset": None}
232
+
233
+
234
+ # Generate Medusa config for pushing to HF hub
235
+ medusa_config = MedusaConfig(
236
+ medusa_num_heads=training_args.medusa_num_heads,
237
+ medusa_num_layers=training_args.medusa_num_layers,
238
+ base_model_name_or_path=model_args.model_name_or_path,
239
+ )
240
+
241
+ # Save Medusa config
242
+ medusa_config.save_pretrained(training_args.output_dir)
243
+
244
+ # Start trainner
245
+ trainer = CustomizedTrainer(
246
+ model=medusa_lm_head, tokenizer=tokenizer, args=training_args, **data_module
247
+ )
248
+
249
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
250
+ trainer.train(resume_from_checkpoint=True)
251
+ else:
252
+ trainer.train()
253
+ model.config.use_cache = True
254
+
255
+ # Save MedusaHead seperately
256
+ if hasattr(medusa_lm_head, "module"):
257
+ lm_head = medusa_lm_head.module.medusa_head
258
+ else:
259
+ lm_head = medusa_lm_head.medusa_head
260
+
261
+ # Save Medusa heads
262
+ torch.save(
263
+ lm_head.state_dict(),
264
+ os.path.join(training_args.output_dir, "medusa_lm_head.pt"),
265
+ )
266
+
267
+
268
+ if __name__ == "__main__":
269
+ train()
medusa_training.py β†’ src/train_workflow.py RENAMED
@@ -1,4 +1,6 @@
1
- import json
 
 
2
  import os
3
  import multiprocessing as mp
4
 
@@ -9,26 +11,23 @@ import torch
9
  import torch.distributed.run as distributed_run
10
 
11
  OUTPUT_DIR = "medusa_heads"
12
- MEDUSA_NUM_HEADS = 3
13
- MEDUSA_NUM_LAYERS = 1
14
- LR = 1e-3
15
 
16
  DATASET = "vicuna"
17
 
18
  # These can't be changed (e.g. they control the output path)
19
  FIXED_TRAINING_ARGS = \
20
- """medusa/medusa/train/train.py
21
  --model_name_or_path {model_id}
22
  --output_dir {output_dir}
23
  --run_name {model_id}-medusa-{dataset}
24
- --medusa_num_heads {medusa_num_heads}
25
- --medusa_num_layers {medusa_num_layers}
26
- --learning_rate {lr}
27
- --data_path data/ShareGPT_V4.3_unfiltered_cleaned_split.json"""
28
 
29
  # These can be freely changed
30
  DEFAULT_TRAINING_ARGS = \
31
- """--bf16 True
 
 
 
32
  --num_train_epochs 1
33
  --per_device_train_batch_size 64
34
  --per_device_eval_batch_size 64
@@ -40,19 +39,13 @@ DEFAULT_TRAINING_ARGS = \
40
  --lr_scheduler_type cosine
41
  --logging_steps 10
42
  --tf32 True
43
- --model_max_length 2048
44
- --lazy_preprocess True
45
- --auto_find_batch_size True"""
46
 
47
 
48
- def train_medusa_heads(model_id: str, training_args: str):
49
  all_training_args = FIXED_TRAINING_ARGS.format(
50
- model_id=model_id,
51
- output_dir=OUTPUT_DIR,
52
- dataset=DATASET,
53
- medusa_num_heads=MEDUSA_NUM_HEADS,
54
- lr=LR,
55
- medusa_num_layers=MEDUSA_NUM_LAYERS
56
  ) + "\n" + training_args
57
  all_training_arg_list = []
58
  for arg in all_training_args.split("\n"):
@@ -64,11 +57,11 @@ def train_medusa_heads(model_id: str, training_args: str):
64
  distributed_run.run(args)
65
 
66
 
67
- def run(model_id: str, training_args: str) -> str:
68
  print(f"\n\n\nNEW RUN: {model_id}")
69
  api = HfApi()
70
  model_name = model_id.split("/")[-1]
71
- repo_id = f"joaogante/{model_name}-medusa-{DATASET}"
72
 
73
  # Input validation
74
  if model_id == "":
@@ -101,7 +94,7 @@ def run(model_id: str, training_args: str) -> str:
101
 
102
  # Run the medusa heads creation
103
  try:
104
- proc = mp.Process(target=train_medusa_heads, args=(model_id, training_args))
105
  proc.start()
106
  proc.join()
107
  print("Medusa heads training process completed (it might have crashed!)")
@@ -117,7 +110,7 @@ def run(model_id: str, training_args: str) -> str:
117
  try:
118
  # Folder path from https://github.com/FasterDecoding/Medusa/blob/main/medusa/train/train.py#L399
119
  folder_path = (
120
- f"{OUTPUT_DIR}_medusa_mlp_{model_name}_medusa_{MEDUSA_NUM_HEADS}_lr_{LR}_layers_{MEDUSA_NUM_LAYERS}"
121
  )
122
  if not any([x for x in os.listdir(folder_path) if len(x) >= 3 and x[-3:] == ".pt"]):
123
  raise Exception(
 
1
+ """
2
+ Holds the interface between the gradio app and the medusa training script
3
+ """
4
  import os
5
  import multiprocessing as mp
6
 
 
11
  import torch.distributed.run as distributed_run
12
 
13
  OUTPUT_DIR = "medusa_heads"
 
 
 
14
 
15
  DATASET = "vicuna"
16
 
17
  # These can't be changed (e.g. they control the output path)
18
  FIXED_TRAINING_ARGS = \
19
+ """src/medusa_training_script.py
20
  --model_name_or_path {model_id}
21
  --output_dir {output_dir}
22
  --run_name {model_id}-medusa-{dataset}
23
+ --dataset {dataset}"""
 
 
 
24
 
25
  # These can be freely changed
26
  DEFAULT_TRAINING_ARGS = \
27
+ """--medusa_num_heads 3
28
+ --medusa_num_layers 1
29
+ --model_max_length 2048
30
+ --bf16 True
31
  --num_train_epochs 1
32
  --per_device_train_batch_size 64
33
  --per_device_eval_batch_size 64
 
39
  --lr_scheduler_type cosine
40
  --logging_steps 10
41
  --tf32 True
42
+ --auto_find_batch_size True
43
+ --learning_rate 1e-3"""
 
44
 
45
 
46
+ def train_medusa_heads(model_id: str, training_args: str, dataset: str):
47
  all_training_args = FIXED_TRAINING_ARGS.format(
48
+ model_id=model_id, output_dir=OUTPUT_DIR, dataset=dataset,
 
 
 
 
 
49
  ) + "\n" + training_args
50
  all_training_arg_list = []
51
  for arg in all_training_args.split("\n"):
 
57
  distributed_run.run(args)
58
 
59
 
60
+ def run(model_id: str, training_args: str, dataset: str) -> str:
61
  print(f"\n\n\nNEW RUN: {model_id}")
62
  api = HfApi()
63
  model_name = model_id.split("/")[-1]
64
+ repo_id = f"joaogante/{model_name}-medusa-{dataset}"
65
 
66
  # Input validation
67
  if model_id == "":
 
94
 
95
  # Run the medusa heads creation
96
  try:
97
+ proc = mp.Process(target=train_medusa_heads, args=(model_id, training_args, dataset))
98
  proc.start()
99
  proc.join()
100
  print("Medusa heads training process completed (it might have crashed!)")
 
110
  try:
111
  # Folder path from https://github.com/FasterDecoding/Medusa/blob/main/medusa/train/train.py#L399
112
  folder_path = (
113
+ f"{OUTPUT_DIR}_medusa_{model_name}"
114
  )
115
  if not any([x for x in os.listdir(folder_path) if len(x) >= 3 and x[-3:] == ".pt"]):
116
  raise Exception(