Fix cache default value (#8)
Browse files- Fix cache default value (eb8c33d7cd6a1fefa1d977aebd39ae331185ec5d)
- configuration_cased.py +3 -2
configuration_cased.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import os
|
|
|
2 |
|
3 |
from transformers.modeling_utils import PretrainedConfig
|
4 |
|
@@ -21,11 +22,11 @@ class CaSEDConfig(PretrainedConfig):
|
|
21 |
index_name: str = "cc12m",
|
22 |
alpha: float = 0.5,
|
23 |
retrieval_num_results: int = 10,
|
24 |
-
cache_dir: str =
|
25 |
**kwargs,
|
26 |
):
|
27 |
super().__init__(**kwargs)
|
28 |
self.index_name = index_name
|
29 |
self.alpha = alpha
|
30 |
self.retrieval_num_results = retrieval_num_results
|
31 |
-
self.cache_dir = cache_dir
|
|
|
1 |
import os
|
2 |
+
from typing import Optional
|
3 |
|
4 |
from transformers.modeling_utils import PretrainedConfig
|
5 |
|
|
|
22 |
index_name: str = "cc12m",
|
23 |
alpha: float = 0.5,
|
24 |
retrieval_num_results: int = 10,
|
25 |
+
cache_dir: Optional[str] = None,
|
26 |
**kwargs,
|
27 |
):
|
28 |
super().__init__(**kwargs)
|
29 |
self.index_name = index_name
|
30 |
self.alpha = alpha
|
31 |
self.retrieval_num_results = retrieval_num_results
|
32 |
+
self.cache_dir = cache_dir or os.path.expanduser("~/.cache/cased")
|