winglian committed on
Commit
8c2f3cb
1 Parent(s): b46bc02

support for Replit LM

Browse files
examples/replit-3b/config-lora.yml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_model: replit/replit-code-v1-3b
2
+ base_model_config: replit/replit-code-v1-3b
3
+ trust_remote_code: true
4
+ load_in_8bit: false
5
+ datasets:
6
+ - path: vicgalle/alpaca-gpt4
7
+ type: alpaca
8
+ dataset_prepared_path: last_run_prepared
9
+ val_set_size: 0.05
10
+ adapter: lora
11
+ lora_model_dir:
12
+ sequence_len: 2048
13
+ max_packed_sequence_len:
14
+ lora_r: 8
15
+ lora_alpha: 16
16
+ lora_dropout: 0.05
17
+ lora_target_modules:
18
+ - Wqkv
19
+ - mlp_up
20
+ - mlp_down
21
+ lora_fan_in_fan_out:
22
+ wandb_project: lora-replit
23
+ wandb_watch:
24
+ wandb_run_id:
25
+ wandb_log_model:
26
+ output_dir: ./lora-replit
27
+ batch_size: 8
28
+ micro_batch_size: 1
29
+ num_epochs: 3
30
+ optimizer:
31
+ torchdistx_path:
32
+ lr_scheduler:
33
+ learning_rate: 0.00001
34
+ train_on_inputs: false
35
+ group_by_length: false
36
+ bf16: true
37
+ tf32: true
38
+ gradient_checkpointing:
39
+ early_stopping_patience:
40
+ resume_from_checkpoint:
41
+ local_rank:
42
+ logging_steps: 1
43
+ xformers_attention:
44
+ flash_attention:
45
+ gptq_groupsize:
46
+ gptq_model_v1:
47
+ warmup_steps: 20
48
+ eval_steps: 50
49
+ save_steps:
50
+ debug:
51
+ deepspeed:
52
+ weight_decay: 0
53
+ fsdp:
54
+ fsdp_config:
55
+ #special_tokens:
src/axolotl/utils/models.py CHANGED
@@ -163,11 +163,20 @@ def load_model(
163
  if not tokenizer:
164
  try:
165
  if is_llama_derived_model and "LlamaTokenizer" in globals():
166
- tokenizer = LlamaTokenizer.from_pretrained(model)
 
 
 
167
  else:
168
- tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
 
 
 
169
  except:
170
- tokenizer = AutoTokenizer.from_pretrained(base_model_config)
 
 
 
171
 
172
  logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
173
  logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
 
163
  if not tokenizer:
164
  try:
165
  if is_llama_derived_model and "LlamaTokenizer" in globals():
166
+ tokenizer = LlamaTokenizer.from_pretrained(
167
+ model,
168
+ trust_remote_code=True if cfg.trust_remote_code is True else False,
169
+ )
170
  else:
171
+ tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
172
+ model,
173
+ trust_remote_code=True if cfg.trust_remote_code is True else False,
174
+ )
175
  except:
176
+ tokenizer = AutoTokenizer.from_pretrained(
177
+ base_model_config,
178
+ trust_remote_code=True if cfg.trust_remote_code is True else False,
179
+ )
180
 
181
  logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
182
  logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")