Spaces:
Paused
Paused
Update finetune.py
Browse files- finetune.py +6 -6
finetune.py
CHANGED
@@ -17,13 +17,9 @@ from peft import (
|
|
17 |
|
18 |
HF_TOKEN = os.environ.get("TRL_TOKEN", None)
|
19 |
if HF_TOKEN:
|
20 |
-
try:
|
21 |
-
shutil.rmtree("./data/")
|
22 |
-
except:
|
23 |
-
pass
|
24 |
|
25 |
repo = Repository(
|
26 |
-
local_dir="./
|
27 |
)
|
28 |
repo.git_pull()
|
29 |
# Parameters
|
@@ -163,6 +159,8 @@ trainer = transformers.Trainer(
|
|
163 |
save_total_limit=100,
|
164 |
load_best_model_at_end=True if VAL_SET_SIZE > 0 else False,
|
165 |
ddp_find_unused_parameters=False if ddp else None,
|
|
|
|
|
166 |
|
167 |
),
|
168 |
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
@@ -175,6 +173,8 @@ model.state_dict = (
|
|
175 |
).__get__(model, type(model))
|
176 |
|
177 |
trainer.train()
|
178 |
-
trainer.push_to_hub()
|
179 |
|
180 |
model.save_pretrained(OUTPUT_DIR)
|
|
|
|
|
|
|
|
17 |
|
18 |
HF_TOKEN = os.environ.get("TRL_TOKEN", None)
|
19 |
if HF_TOKEN:
|
|
|
|
|
|
|
|
|
20 |
|
21 |
repo = Repository(
|
22 |
+
local_dir="./checkpoints/", clone_from="gustavoaq/llama_ft", use_auth_token=HF_TOKEN, repo_type="models"
|
23 |
)
|
24 |
repo.git_pull()
|
25 |
# Parameters
|
|
|
159 |
save_total_limit=100,
|
160 |
load_best_model_at_end=True if VAL_SET_SIZE > 0 else False,
|
161 |
ddp_find_unused_parameters=False if ddp else None,
|
162 |
+
push_to_hub=True,
|
163 |
+
push_to_hub_model_id='llama_ft'
|
164 |
|
165 |
),
|
166 |
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
|
|
173 |
).__get__(model, type(model))
|
174 |
|
175 |
trainer.train()
|
|
|
176 |
|
177 |
model.save_pretrained(OUTPUT_DIR)
|
178 |
+
trainer.push_to_hub()
|
179 |
+
|
180 |
+
|