khoicrtp committed
Commit 953210f
1 parent: 5fe70fd

Upload 2 files

Files changed (2)
  1. main.py +20 -4
  2. prepare.sh +4 -3
main.py CHANGED
@@ -92,14 +92,24 @@ print(dataset_data[0])
 with open("mitre-dataset.json", "w") as f:
     json.dump(dataset_data, f)
 
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)
 
 BASE_MODEL = "decapoda-research/llama-7b-hf"
 
-model = LlamaForCausalLM.from_pretrained(
+device_map = {
+    "transformer.word_embeddings": 0,
+    "transformer.word_embeddings_layernorm": 0,
+    "lm_head": "cpu",
+    "transformer.h": 0,
+    "transformer.ln_f": 0,
+}
+
+model = AutoModelForCausalLM.from_pretrained(
     BASE_MODEL,
-    load_in_8bit=True,
-    torch_dtype=torch.float16,
     device_map="auto",
+    quantization_config=quantization_config,
 )
 
 tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
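
This hunk swaps the direct 8-bit LlamaForCausalLM load for the BitsAndBytesConfig route with fp32 CPU offload. Two things worth noting: the new device_map dict is defined but never used, since the from_pretrained call still passes device_map="auto"; and its module names (transformer.word_embeddings, transformer.h, ...) follow the BLOOM example from the transformers docs rather than LLaMA's module layout. Below is a minimal sketch of how the pieces would fit together for this checkpoint, assuming a recent transformers with bitsandbytes and accelerate installed; the LLaMA-shaped map and the explicit load_in_8bit=True are assumptions for illustration, not part of the commit.

from transformers import AutoModelForCausalLM, BitsAndBytesConfig, LlamaTokenizer

BASE_MODEL = "decapoda-research/llama-7b-hf"

# int8 quantization for GPU-resident linear layers; modules mapped to
# "cpu" are kept in fp32, since the int8 kernels only run on GPU.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

# Hypothetical LLaMA-shaped device map: decoder stack on GPU 0, LM head
# offloaded to CPU. These are LlamaForCausalLM module names, unlike the
# BLOOM-style names in the commit's (unused) dict.
device_map = {
    "model.embed_tokens": 0,
    "model.layers": 0,
    "model.norm": 0,
    "lm_head": "cpu",
}

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map=device_map,  # the commit still passes "auto", leaving its dict unused
    quantization_config=quantization_config,
)
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
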
@@ -196,7 +206,6 @@ training_arguments = transformers.TrainingArguments(
     warmup_steps=100,
     max_steps=TRAIN_STEPS,
     learning_rate=LEARNING_RATE,
-    fp16=True,
     logging_steps=10,
     optim="adamw_torch",
     evaluation_strategy="steps",
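
With fp16=True the Trainer runs the forward pass under CUDA mixed precision; combined with int8 weights and fp32 modules offloaded to the CPU, that casting is a common source of dtype errors, which is presumably why this commit drops it. A sketch of the surrounding argument block: only the six arguments visible in the hunk come from the diff, the rest are placeholder assumptions.

import transformers

# Placeholder values for illustration; not visible in the diff.
TRAIN_STEPS = 300
LEARNING_RATE = 3e-4
OUTPUT_DIR = "experiments"

training_arguments = transformers.TrainingArguments(
    per_device_train_batch_size=1,   # assumption
    warmup_steps=100,
    max_steps=TRAIN_STEPS,
    learning_rate=LEARNING_RATE,
    # fp16=True removed by this commit: AMP half-precision casting
    # does not mix well with int8 weights plus fp32 CPU offload.
    logging_steps=10,
    optim="adamw_torch",
    evaluation_strategy="steps",
    eval_steps=50,                   # assumption
    save_steps=50,                   # assumption
    output_dir=OUTPUT_DIR,
)
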
@@ -228,7 +237,14 @@ model.state_dict = (
     )
 ).__get__(model, type(model))
 
+print("Compiling model...")
 model = torch.compile(model)
+print("Done compiling model...")
 
+print("Training model...")
 trainer.train()
+print("Done training model...")
+
+print("Saving model...")
 model.save_pretrained(OUTPUT_DIR)
+print("Done saving model...")
prepare.sh CHANGED
@@ -3,6 +3,7 @@ curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh
 apt-get install -y git
 apt-get install -y git-lfs
 apt-get install -y python3-pip
-git clone https://huggingface.co/khoicrtp/test_model
-cd test_model
-python3 finetune_lora.py
+git clone https://huggingface.co/khoicrtp/test_scratch
+cd test_scratch
+pip3 install -r requirements.txt --user
+python3 main.py
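
The provisioning script now clones test_scratch instead of test_model, installs the repo's Python dependencies before running, and launches main.py rather than finetune_lora.py. Putting the hunk back into a full script for reference: the curl line is reconstructed from the truncated hunk context and its exact form is an assumption, and the bare apt-get calls imply the script runs as root.

#!/bin/bash
# Assumed from the hunk context: packagecloud setup script for git-lfs.
curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash

apt-get install -y git
apt-get install -y git-lfs
apt-get install -y python3-pip

# Fetch the training repo, install its dependencies, and start fine-tuning.
git clone https://huggingface.co/khoicrtp/test_scratch
cd test_scratch
pip3 install -r requirements.txt --user
python3 main.py
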