# tricksy/configuration_tricksy.py
# austinsilveria: "Add tricksy" (commit 75fa479)
import dataclasses
import torch
from transformers.models.opt.configuration_opt import OPTConfig
@dataclasses.dataclass(frozen=True)
class TricksyConfig:
    """Immutable bundle of Tricksy settings layered on top of an OPT config.

    The dataclass is frozen, so fields cannot be reassigned after
    construction; use ``dataclasses.replace(cfg, field=value)`` to derive a
    modified copy.
    """

    # Configuration of the underlying OPT model.
    opt_config: OPTConfig

    # Fraction (0.0-1.0, not 0-100) of each MLP layer's weights to keep on
    # each device.
    # e.g. 0.3 keeps 30% of each MLP layer on GPU
    min_mlp_sparsity_gpu: float = 0.3
    # e.g. 1.0 keeps 100% of each MLP layer on CPU
    min_mlp_sparsity_cpu: float = 1.0

    # If true, cleans up a layer's weights after computing its forward pass.
    full_offload: bool = False

    # NOTE(review): presumably the dtype the model weights are held in --
    # confirm against the code that consumes this config.
    dtype: torch.dtype = torch.float16