aoxo committed
Commit 9b490bf
1 Parent(s): b1547c3

added finetuning and trainer scripts

Files changed (2)
  1. finetuning.py +70 -0
  2. trainer.py +38 -0
finetuning.py ADDED
@@ -0,0 +1,70 @@
+ from datasets import load_dataset
+ from trl import SFTTrainer
+ from peft import LoraConfig
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig
+
+ # Check that the fast Mamba kernels import cleanly.
+ # If they don't, this will (very appropriately) fail before the weights are loaded.
+ import mamba_ssm
+
+ # With 4-bit quants you currently have to patch modeling_jamba.py manually on l. 1070:
+ # if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type:
+ # becomes:
+ # if not is_fast_path_available:
+
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     # The argument is named llm_int8_skip_modules even when loading in 4-bit.
+     # Maybe not necessary (per axolotl), but worth testing.
+     llm_int8_skip_modules=["mamba"],
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained("jamba")
+
+ dataset = load_dataset("VishnuPJ/Malayalam_CultureX_IndicCorp_SMC", split="train")
+
+ training_args = TrainingArguments(
+     output_dir="./results",
+     num_train_epochs=1,
+     per_device_train_batch_size=1,
+     gradient_accumulation_steps=4,
+     optim="adamw_8bit",
+     max_grad_norm=0.3,
+     weight_decay=0.001,
+     warmup_ratio=0.03,
+     gradient_checkpointing=True,
+     logging_dir="./logs",
+     logging_steps=1,
+     max_steps=50,
+     group_by_length=True,
+     lr_scheduler_type="linear",
+     learning_rate=2e-3,
+ )
+
+ lora_config = LoraConfig(
+     lora_alpha=16,
+     lora_dropout=0.05,
+     init_lora_weights=False,
+     r=8,
+     target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
+     task_type="CAUSAL_LM",
+     bias="none",
+ )
+
+ model = AutoModelForCausalLM.from_pretrained(
+     "jamba",
+     trust_remote_code=True,
+     device_map="auto",
+     attn_implementation="flash_attention_2",
+     quantization_config=quantization_config,
+     use_mamba_kernels=True,
+ )
+
+ trainer = SFTTrainer(
+     model=model,
+     tokenizer=tokenizer,
+     args=training_args,
+     peft_config=lora_config,
+     train_dataset=dataset,
+     max_seq_length=256,
+     # Must match the dataset's text column; "quote" is the TRL example default
+     # and likely needs changing for this corpus.
+     dataset_text_field="quote",
+ )
+
+ trainer.train()
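finetuning.py stops at trainer.train() and never writes the adapter to disk. A minimal follow-up sketch, using the objects defined above and an illustrative output path (not part of the committed file):

trainer.save_model("./results/jamba-lora")         # writes the LoRA adapter weights and config
tokenizer.save_pretrained("./results/jamba-lora")  # keep the tokenizer alongside the adapter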
trainer.py ADDED
@@ -0,0 +1,38 @@
+ # Smoke test: ai21labs/Jamba-tiny-random is a tiny, randomly initialized checkpoint,
+ # so this run only validates the training pipeline, not model quality.
+ from datasets import load_dataset
+ from peft import LoraConfig
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
+ from trl import SFTTrainer
+
+ # Load model directly
+ tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-tiny-random")
+ model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-tiny-random", trust_remote_code=True)
+
+ dataset = load_dataset("rajeshradhakrishnan/malayalam_wiki")
+
+ training_args = TrainingArguments(
+     output_dir="./results",
+     num_train_epochs=3,
+     per_device_train_batch_size=3,
+     logging_dir="./logs",
+     logging_steps=10,
+     learning_rate=2e-3,
+ )
+
+ lora_config = LoraConfig(
+     r=8,
+     target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
+     task_type="CAUSAL_LM",
+     bias="none",
+ )
+
+ trainer = SFTTrainer(
+     model=model,
+     tokenizer=tokenizer,
+     args=training_args,
+     peft_config=lora_config,
+     train_dataset=dataset["train"],
+     dataset_text_field="text",
+ )
+
+ trainer.train()
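After the smoke test, a quick sanity check is to generate a few tokens with the trained model. A hedged sketch using the objects from trainer.py (the prompt is arbitrary, and the output will be noise since the base weights are random):

device = next(trainer.model.parameters()).device
inputs = tokenizer("Malayalam Wikipedia", return_tensors="pt").to(device)
outputs = trainer.model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))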