Spaces:
Runtime error
Runtime error
zetavg
commited on
Commit
•
966795b
1
Parent(s):
fa0947b
cache lora models
Browse files- llama_lora/globals.py +2 -2
- llama_lora/models.py +12 -12
- requirements.lock.txt +1 -1
llama_lora/globals.py
CHANGED
@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|
6 |
from numba import cuda
|
7 |
import nvidia_smi
|
8 |
|
|
|
9 |
from .lib.finetune import train
|
10 |
|
11 |
|
@@ -31,8 +32,7 @@ class Global:
|
|
31 |
|
32 |
# Model related
|
33 |
model_has_been_used = False
|
34 |
-
|
35 |
-
loaded_base_model_with_lora_name = None
|
36 |
|
37 |
# GPU Info
|
38 |
gpu_cc = None # GPU compute capability
|
|
|
6 |
from numba import cuda
|
7 |
import nvidia_smi
|
8 |
|
9 |
+
from .utils.lru_cache import LRUCache
|
10 |
from .lib.finetune import train
|
11 |
|
12 |
|
|
|
32 |
|
33 |
# Model related
|
34 |
model_has_been_used = False
|
35 |
+
cached_lora_models = LRUCache(10)
|
|
|
36 |
|
37 |
# GPU Info
|
38 |
gpu_cc = None # GPU compute capability
|
llama_lora/models.py
CHANGED
@@ -31,30 +31,32 @@ def get_base_model():
|
|
31 |
return Global.loaded_base_model
|
32 |
|
33 |
|
34 |
-
def get_model_with_lora(
|
35 |
Global.model_has_been_used = True
|
36 |
|
37 |
-
if Global.
|
38 |
-
|
|
|
|
|
39 |
|
40 |
if device == "cuda":
|
41 |
model = PeftModel.from_pretrained(
|
42 |
get_base_model(),
|
43 |
-
|
44 |
torch_dtype=torch.float16,
|
45 |
device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
|
46 |
)
|
47 |
elif device == "mps":
|
48 |
model = PeftModel.from_pretrained(
|
49 |
get_base_model(),
|
50 |
-
|
51 |
device_map={"": device},
|
52 |
torch_dtype=torch.float16,
|
53 |
)
|
54 |
else:
|
55 |
model = PeftModel.from_pretrained(
|
56 |
get_base_model(),
|
57 |
-
|
58 |
device_map={"": device},
|
59 |
)
|
60 |
|
@@ -69,8 +71,9 @@ def get_model_with_lora(lora_weights: str = "tloen/alpaca-lora-7b"):
|
|
69 |
if torch.__version__ >= "2" and sys.platform != "win32":
|
70 |
model = torch.compile(model)
|
71 |
|
72 |
-
Global.
|
73 |
-
|
|
|
74 |
return model
|
75 |
|
76 |
|
@@ -127,10 +130,7 @@ def unload_models():
|
|
127 |
del Global.loaded_tokenizer
|
128 |
Global.loaded_tokenizer = None
|
129 |
|
130 |
-
|
131 |
-
Global.loaded_base_model_with_lora = None
|
132 |
-
|
133 |
-
Global.loaded_base_model_with_lora_name = None
|
134 |
|
135 |
clear_cache()
|
136 |
|
|
|
31 |
return Global.loaded_base_model
|
32 |
|
33 |
|
34 |
+
def get_model_with_lora(lora_weights_name_or_path: str = "tloen/alpaca-lora-7b"):
|
35 |
Global.model_has_been_used = True
|
36 |
|
37 |
+
if Global.cached_lora_models:
|
38 |
+
model_from_cache = Global.cached_lora_models.get(lora_weights_name_or_path)
|
39 |
+
if model_from_cache:
|
40 |
+
return model_from_cache
|
41 |
|
42 |
if device == "cuda":
|
43 |
model = PeftModel.from_pretrained(
|
44 |
get_base_model(),
|
45 |
+
lora_weights_name_or_path,
|
46 |
torch_dtype=torch.float16,
|
47 |
device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
|
48 |
)
|
49 |
elif device == "mps":
|
50 |
model = PeftModel.from_pretrained(
|
51 |
get_base_model(),
|
52 |
+
lora_weights_name_or_path,
|
53 |
device_map={"": device},
|
54 |
torch_dtype=torch.float16,
|
55 |
)
|
56 |
else:
|
57 |
model = PeftModel.from_pretrained(
|
58 |
get_base_model(),
|
59 |
+
lora_weights_name_or_path,
|
60 |
device_map={"": device},
|
61 |
)
|
62 |
|
|
|
71 |
if torch.__version__ >= "2" and sys.platform != "win32":
|
72 |
model = torch.compile(model)
|
73 |
|
74 |
+
if Global.cached_lora_models:
|
75 |
+
Global.cached_lora_models.set(lora_weights_name_or_path, model)
|
76 |
+
|
77 |
return model
|
78 |
|
79 |
|
|
|
130 |
del Global.loaded_tokenizer
|
131 |
Global.loaded_tokenizer = None
|
132 |
|
133 |
+
Global.cached_lora_models.clear()
|
|
|
|
|
|
|
134 |
|
135 |
clear_cache()
|
136 |
|
requirements.lock.txt
CHANGED
@@ -65,7 +65,7 @@ packaging==23.0
|
|
65 |
pandas==2.0.0
|
66 |
parso==0.8.3
|
67 |
pathspec==0.11.1
|
68 |
-
peft @ git+https://github.com/huggingface/peft.git@
|
69 |
pexpect==4.8.0
|
70 |
pickleshare==0.7.5
|
71 |
Pillow==9.3.0
|
|
|
65 |
pandas==2.0.0
|
66 |
parso==0.8.3
|
67 |
pathspec==0.11.1
|
68 |
+
peft @ git+https://github.com/huggingface/peft.git@382b178911edff38c1ff619bbac2ba556bd2276b
|
69 |
pexpect==4.8.0
|
70 |
pickleshare==0.7.5
|
71 |
Pillow==9.3.0
|