# NRJ-DEBUG / configuration_energy.py
from math import sqrt, log
import sys
# sys.path.append("../energy")  # Messy
import torch
import torch.nn as nn
from torch.nn.functional import softmax, relu, linear
from common import PositionalEncoding
from hopfield import HopfieldLayer, HopfieldMHA, HopfieldReLU, HopfieldSoftmax
from torch.cuda.amp import autocast
import yaml
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import MaskedLMOutput, BaseModelOutput


class BertEnergyConfig(PretrainedConfig):
    """Configuration for the BERT-Energy model.

    Defaults can be overridden either from an existing config object
    (``config=...``) or from a ``.yaml`` file (``path=...``); any
    remaining keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "bert_energy"

    def __init__(self, config=None, path=None, vocabulary_size=50,
                 num_layers=12, num_heads=12, forward_memories=2048,
                 embedding_dim=768, activation="relu", positional=True,
                 bias=True, tie_weights=True, alpha=1.0, beta=1.0,
                 layer_norm=1e-05, dropout=0.0, block_size=512,
                 share_layers=False, compile=False, pad_idx=None, **kwargs):
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.activation = activation
        self.positional = positional
        self.tie_weights = tie_weights
        self.bias = bias
        self.forward_memories = forward_memories
        self.embedding_dim = embedding_dim
        self.share_layers = share_layers
        self.alpha = alpha
        self.beta = beta
        # YAML may parse scientific notation (e.g. 1e-05) as a string,
        # so coerce the layer-norm epsilon to float.
        self.layer_norm = float(layer_norm)
        self.dropout = dropout
        self.block_size = block_size
        self.compile = compile
        self.pad_idx = pad_idx
        if config is not None:
            # Copy matching keys from an existing config over the defaults.
            # to_dict() returns a dict, so iterate over .items().
            for key, value in config.to_dict().items():
                if key.lower() in self.__dict__:
                    print(key, file=sys.stderr)  # debug trace
                    setattr(self, key.lower(), value)
        elif path is not None:
            if path.endswith(".yaml"):
                # Load overrides from a YAML file whose top-level keys
                # mirror the attribute names above.
                with open(path) as istream:
                    config = yaml.safe_load(istream)
                for key, value in config.items():
                    if key.lower() in self.__dict__:
                        print(key, file=sys.stderr)  # debug trace
                        setattr(self, key.lower(), value)
            else:
                raise NotImplementedError(f"Unsupported config format: {path}")
        super().__init__(**kwargs)
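

# A minimal usage sketch (illustrative values only; not from the original
# repo). It exercises the three construction paths the class supports:
# keyword arguments, copying another config, and loading a YAML file.
if __name__ == "__main__":
    # 1) From keyword arguments (unknown keys are passed to PretrainedConfig).
    cfg = BertEnergyConfig(vocabulary_size=32000, num_layers=6, embedding_dim=512)
    print(cfg.embedding_dim)  # 512

    # 2) From another config: keys matching the attributes above are copied
    #    over the defaults.
    cfg_copy = BertEnergyConfig(config=cfg)
    assert cfg_copy.vocabulary_size == 32000

    # 3) From a YAML file whose top-level keys mirror the attribute names, e.g.
    #      vocabulary_size: 32000
    #      num_layers: 6
    #    ("config.yaml" is a hypothetical path; uncomment with a real file.)
    # cfg_yaml = BertEnergyConfig(path="config.yaml")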