pipeline1 / configuration_stacked.py
Gleb Vinarskis
debug
51f79b2
raw
history blame contribute delete
684 Bytes
from transformers import PretrainedConfig
import torch
class ImpressoConfig(PretrainedConfig):
    """Configuration for the floret-based Impresso language-identification model.

    The model is a single floret ``.bin`` binary; ``filename`` records which
    binary file to load. There is no ``config.json`` on the hub, so the
    standard JSON-based config loading is bypassed (see ``from_pretrained``).
    """

    model_type = "floret"

    def __init__(self, filename="LID-40-3-2000000-1-4.bin", **kwargs):
        """Build the config.

        Args:
            filename: Name of the floret binary model file to load.
            **kwargs: Forwarded unchanged to ``PretrainedConfig``.
        """
        super().__init__(**kwargs)
        self.filename = filename

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Create the config directly, bypassing JSON config loading.

        The checkpoint ships no ``config.json``, so the inherited
        ``PretrainedConfig.from_pretrained`` machinery cannot be used;
        the config is constructed in code instead. ``pretrained_model_name_or_path``
        is only echoed for debugging, not read.
        """
        print(f"Loading ImpressoConfig from {pretrained_model_name_or_path}")
        # Respect a caller-supplied filename. The previous version passed
        # filename="..." AND **kwargs, which raised
        # "TypeError: got multiple values for keyword argument 'filename'"
        # whenever kwargs already contained "filename".
        kwargs.setdefault("filename", "LID-40-3-2000000-1-4.bin")
        return cls(**kwargs)
# Register this config class with transformers' auto-class machinery so the
# custom class is recorded in the saved config's auto_map (presumably for
# loading via AutoConfig with trust_remote_code=True — confirm against usage).
ImpressoConfig.register_for_auto_class()