PatrickHaller commited on
Commit
f57435d
1 Parent(s): 33340e5

Upload configuration_xlstm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. configuration_xlstm.py +97 -0
configuration_xlstm.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import Any, Dict, Optional
3
+
4
+ from dacite import Config as DaciteConfig
5
+ from dacite import from_dict
6
+ from omegaconf import OmegaConf
7
+ from transformers.configuration_utils import PretrainedConfig
8
+ from xlstm import xLSTMLMModelConfig
9
+
10
+ # from .config_presets import xlstm_cfg_map
11
+
12
+
13
+ class xLSTMConfig(PretrainedConfig):
14
+ """XLSTM configuration class.
15
+ We seperate the specific xLSTM model configuration
16
+ from the rest due to the heavy nesting of the configuration.
17
+ """
18
+
19
+ model_type = "xlstm"
20
+
21
+ def __init__(
22
+ self, vocab_size: int = 32000, config: Optional[Dict[str, Any]] = None, **kwargs
23
+ ):
24
+ super().__init__(**kwargs)
25
+
26
+ cfg = OmegaConf.create(config)
27
+ cfg["vocab_size"] = vocab_size
28
+ for key, value in kwargs.items():
29
+ cfg[key] = value
30
+
31
+ self._xlstm_config = cfg
32
+ self.vocab_size = vocab_size
33
+ self.embedding_dim = cfg.get("embedding_dim")
34
+ self.context_length = cfg.get("context_length")
35
+
36
+ def to_xlstm_config(self):
37
+ return from_dict(
38
+ data_class=xLSTMLMModelConfig,
39
+ data=OmegaConf.to_container(self._xlstm_config),
40
+ config=DaciteConfig(strict=True),
41
+ )
42
+
43
+ def to_dict(self) -> Dict[str, Any]:
44
+ """
45
+ Converts the configuration to a dictionary for serialization.
46
+ """
47
+ output = super().to_dict()
48
+ output["_xlstm_config"] = OmegaConf.to_container(
49
+ self._xlstm_config, resolve=True
50
+ )
51
+ relevant_keys = [
52
+ "vocab_size",
53
+ "embedding_dim",
54
+ "context_length",
55
+ "torch_dtype",
56
+ "_xlstm_config",
57
+ "transformers_version",
58
+ "architectures",
59
+ "model_type",
60
+ ]
61
+ output_ = output.copy()
62
+ for key in output.keys():
63
+ if key not in relevant_keys:
64
+ output_.pop(key)
65
+ return output_
66
+
67
+ @classmethod
68
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
69
+ """
70
+ Creates a configuration instance from a dictionary.
71
+ """
72
+ xlstm_config = config_dict.pop("_xlstm_config")
73
+ vocab_size = config_dict.pop("vocab_size")
74
+ config = cls(vocab_size=vocab_size, config=xlstm_config)
75
+ if "auto_map" in config_dict and config_dict["auto_map"]:
76
+ setattr(config, "auto_map", config_dict.pop("auto_map"))
77
+
78
+ # breakpoint()
79
+ # config.xlstm_config = xlstm_config
80
+ if "return_unused_kwargs" in kwargs and kwargs["return_unused_kwargs"]:
81
+ return config, {}
82
+
83
+ return config
84
+
85
+ def to_json_string(self, *args, **kwargs) -> str:
86
+ """
87
+ Serializes the instance to a JSON string.
88
+ """
89
+ return json.dumps(self.to_dict(), indent=2)
90
+
91
+ @classmethod
92
+ def from_json_string(cls, json_string: str):
93
+ """
94
+ Deserializes the instance from a JSON string.
95
+ """
96
+ config_dict = json.loads(json_string)
97
+ return cls.from_dict(config_dict)