lxe committed on
Commit
7d21166
1 Parent(s): ecf29d8

Cleanup and gc after training

Browse files
Files changed (1) hide show
  1. main.py +13 -5
main.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import argparse
3
  import random
4
  import torch
@@ -44,6 +45,10 @@ def reset_model():
44
  del model
45
  del tokenizer
46
 
 
 
 
 
47
  model = None
48
  tokenizer = None
49
  current_peft_model = None
@@ -95,11 +100,12 @@ def generate_text(
95
  num_beams=1,
96
  )
97
 
98
- output = model.generate( # type: ignore
99
- input_ids=input_ids,
100
- attention_mask=torch.ones_like(input_ids),
101
- generation_config=generation_config
102
- )[0].cuda()
 
103
 
104
  return tokenizer.decode(output, skip_special_tokens=True).strip()
105
 
@@ -238,6 +244,8 @@ def tokenize_and_train(
238
 
239
  result = trainer.train(resume_from_checkpoint=False)
240
  model.save_pretrained(output_dir)
 
 
241
  reset_model()
242
 
243
  return result
 
1
  import os
2
+ import gc
3
  import argparse
4
  import random
5
  import torch
 
45
  del model
46
  del tokenizer
47
 
48
+ gc.collect()
49
+ with torch.no_grad():
50
+ torch.cuda.empty_cache()
51
+
52
  model = None
53
  tokenizer = None
54
  current_peft_model = None
 
100
  num_beams=1,
101
  )
102
 
103
+ with torch.no_grad():
104
+ output = model.generate( # type: ignore
105
+ input_ids=input_ids,
106
+ attention_mask=torch.ones_like(input_ids),
107
+ generation_config=generation_config
108
+ )[0].cuda()
109
 
110
  return tokenizer.decode(output, skip_special_tokens=True).strip()
111
 
 
244
 
245
  result = trainer.train(resume_from_checkpoint=False)
246
  model.save_pretrained(output_dir)
247
+
248
+ del data
249
  reset_model()
250
 
251
  return result