zetavg committed on
Commit
966795b
1 Parent(s): fa0947b

cache lora models

Browse files
llama_lora/globals.py CHANGED
@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
6
  from numba import cuda
7
  import nvidia_smi
8
 
 
9
  from .lib.finetune import train
10
 
11
 
@@ -31,8 +32,7 @@ class Global:
31
 
32
  # Model related
33
  model_has_been_used = False
34
- loaded_base_model_with_lora = None
35
- loaded_base_model_with_lora_name = None
36
 
37
  # GPU Info
38
  gpu_cc = None # GPU compute capability
 
6
  from numba import cuda
7
  import nvidia_smi
8
 
9
+ from .utils.lru_cache import LRUCache
10
  from .lib.finetune import train
11
 
12
 
 
32
 
33
  # Model related
34
  model_has_been_used = False
35
+ cached_lora_models = LRUCache(10)
 
36
 
37
  # GPU Info
38
  gpu_cc = None # GPU compute capability
llama_lora/models.py CHANGED
@@ -31,30 +31,32 @@ def get_base_model():
31
  return Global.loaded_base_model
32
 
33
 
34
- def get_model_with_lora(lora_weights: str = "tloen/alpaca-lora-7b"):
35
  Global.model_has_been_used = True
36
 
37
- if Global.loaded_base_model_with_lora and Global.loaded_base_model_with_lora_name == lora_weights:
38
- return Global.loaded_base_model_with_lora
 
 
39
 
40
  if device == "cuda":
41
  model = PeftModel.from_pretrained(
42
  get_base_model(),
43
- lora_weights,
44
  torch_dtype=torch.float16,
45
  device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
46
  )
47
  elif device == "mps":
48
  model = PeftModel.from_pretrained(
49
  get_base_model(),
50
- lora_weights,
51
  device_map={"": device},
52
  torch_dtype=torch.float16,
53
  )
54
  else:
55
  model = PeftModel.from_pretrained(
56
  get_base_model(),
57
- lora_weights,
58
  device_map={"": device},
59
  )
60
 
@@ -69,8 +71,9 @@ def get_model_with_lora(lora_weights: str = "tloen/alpaca-lora-7b"):
69
  if torch.__version__ >= "2" and sys.platform != "win32":
70
  model = torch.compile(model)
71
 
72
- Global.loaded_base_model_with_lora = model
73
- Global.loaded_base_model_with_lora_name = lora_weights
 
74
  return model
75
 
76
 
@@ -127,10 +130,7 @@ def unload_models():
127
  del Global.loaded_tokenizer
128
  Global.loaded_tokenizer = None
129
 
130
- del Global.loaded_base_model_with_lora
131
- Global.loaded_base_model_with_lora = None
132
-
133
- Global.loaded_base_model_with_lora_name = None
134
 
135
  clear_cache()
136
 
 
31
  return Global.loaded_base_model
32
 
33
 
34
+ def get_model_with_lora(lora_weights_name_or_path: str = "tloen/alpaca-lora-7b"):
35
  Global.model_has_been_used = True
36
 
37
+ if Global.cached_lora_models:
38
+ model_from_cache = Global.cached_lora_models.get(lora_weights_name_or_path)
39
+ if model_from_cache:
40
+ return model_from_cache
41
 
42
  if device == "cuda":
43
  model = PeftModel.from_pretrained(
44
  get_base_model(),
45
+ lora_weights_name_or_path,
46
  torch_dtype=torch.float16,
47
  device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
48
  )
49
  elif device == "mps":
50
  model = PeftModel.from_pretrained(
51
  get_base_model(),
52
+ lora_weights_name_or_path,
53
  device_map={"": device},
54
  torch_dtype=torch.float16,
55
  )
56
  else:
57
  model = PeftModel.from_pretrained(
58
  get_base_model(),
59
+ lora_weights_name_or_path,
60
  device_map={"": device},
61
  )
62
 
 
71
  if torch.__version__ >= "2" and sys.platform != "win32":
72
  model = torch.compile(model)
73
 
74
+ if Global.cached_lora_models:
75
+ Global.cached_lora_models.set(lora_weights_name_or_path, model)
76
+
77
  return model
78
 
79
 
 
130
  del Global.loaded_tokenizer
131
  Global.loaded_tokenizer = None
132
 
133
+ Global.cached_lora_models.clear()
 
 
 
134
 
135
  clear_cache()
136
 
requirements.lock.txt CHANGED
@@ -65,7 +65,7 @@ packaging==23.0
65
  pandas==2.0.0
66
  parso==0.8.3
67
  pathspec==0.11.1
68
- peft @ git+https://github.com/huggingface/peft.git@deff03f2c251534fffd2511fc2d440e84cc54b1b
69
  pexpect==4.8.0
70
  pickleshare==0.7.5
71
  Pillow==9.3.0
 
65
  pandas==2.0.0
66
  parso==0.8.3
67
  pathspec==0.11.1
68
+ peft @ git+https://github.com/huggingface/peft.git@382b178911edff38c1ff619bbac2ba556bd2276b
69
  pexpect==4.8.0
70
  pickleshare==0.7.5
71
  Pillow==9.3.0