Spaces:
Paused
Paused
File size: 527 Bytes
75fa479 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import dataclasses
import torch
from transformers.models.opt.configuration_opt import OPTConfig
@dataclasses.dataclass(frozen=True)
class TricksyConfig:
    """Immutable bundle of settings for sparsity-aware weight offloading.

    The dataclass is declared with ``frozen=True``, so fields cannot be
    reassigned after construction; use ``dataclasses.replace`` to derive a
    modified copy.

    The two ``min_mlp_sparsity_*`` fields give the share of each MLP layer's
    weights to keep on a device, written as a fraction (``0.3`` means 30%).
    """

    # Configuration of the underlying OPT model (required, no default).
    opt_config: OPTConfig

    # Share of each MLP layer's weights kept on the GPU (default: 30%).
    min_mlp_sparsity_gpu: float = 0.3

    # Share of each MLP layer's weights kept on the CPU (default: 100%).
    min_mlp_sparsity_cpu: float = 1

    # When true, a layer's weights are cleaned up once its forward pass
    # has been computed.
    full_offload: bool = False

    # Torch dtype setting; defaults to half precision.
    # NOTE(review): presumably the dtype used for the model's weights --
    # confirm against the code that consumes this config.
    dtype: torch.dtype = torch.float16