Enkhai committed
Commit a0cd7ee
1 Parent(s): ea09821
Files changed (5)
  1. config.json +7 -2
  2. config.py +7 -0
  3. gptj.py +74 -0
  4. lora.py +99 -0
  5. pytorch_model.bin +2 -2
config.json CHANGED
@@ -1,10 +1,15 @@
 {
-  "_name_or_path": "hivemind/gpt-j-6B-8bit",
+  "_name_or_path": "EleutherAI/gpt-j-6B",
   "activation_function": "gelu_new",
+  "add_apapters": true,
   "architectures": [
-    "GPTJForCausalLM"
+    "GPTJLoraForCausalLM"
   ],
   "attn_pdrop": 0.0,
+  "auto_map": {
+    "AutoConfig": "config.GPTJLoraConfig",
+    "AutoModel": "gptj.GPTJLoraForCausalLM"
+  },
   "bos_token_id": 50256,
   "eight_bit": true,
   "embd_pdrop": 0.0,
config.py ADDED
@@ -0,0 +1,7 @@
+from transformers import GPTJConfig
+
+
+class GPTJLoraConfig(GPTJConfig):
+    def __init__(self, add_adapters=False, **kwargs):
+        self.add_apapters = add_adapters
+        super().__init__(**kwargs)
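
For illustration only (not part of the commit): the subclassed config simply carries one extra flag on top of GPTJConfig. Note that the flag is stored under the misspelled attribute name add_apapters, which config.json and gptj.py both use consistently, so the spelling is load-bearing.

from config import GPTJLoraConfig  # assumes this repo's config.py is importable

cfg = GPTJLoraConfig(add_adapters=True)
print(cfg.add_apapters)  # True -- the flag GPTJLoraForCausalLM checks before attaching adapters
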
gptj.py ADDED
@@ -0,0 +1,74 @@
+import torch
+from torch import nn
+from .lora import FrozenBNBLinear, FrozenBNBEmbedding
+import transformers
+
+
+def add_adapters(model, adapter_dim=16):
+    assert adapter_dim > 0
+
+    for module in model.modules():
+        if isinstance(module, FrozenBNBLinear):
+            module.adapter = nn.Sequential(
+                nn.Linear(module.in_features, adapter_dim, bias=False),
+                nn.Linear(adapter_dim, module.out_features, bias=False),
+            )
+            nn.init.zeros_(module.adapter[1].weight)
+        elif isinstance(module, FrozenBNBEmbedding):
+            module.adapter = nn.Sequential(
+                nn.Embedding(module.num_embeddings, adapter_dim),
+                nn.Linear(adapter_dim, module.embedding_dim, bias=False),
+            )
+            nn.init.zeros_(module.adapter[1].weight)
+
+
+def convert_to_int8(model):
+    """Convert linear and embedding modules to 8-bit with optional adapters"""
+    for module in list(model.modules()):
+        for name, child in module.named_children():
+            if isinstance(child, nn.Linear):
+                setattr(
+                    module,
+                    name,
+                    FrozenBNBLinear(
+                        weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
+                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
+                        code=torch.zeros(256),
+                        bias=child.bias,
+                    ),
+                )
+            elif isinstance(child, nn.Embedding):
+                setattr(
+                    module,
+                    name,
+                    FrozenBNBEmbedding(
+                        weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
+                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
+                        code=torch.zeros(256),
+                    )
+                )
+
+
+class GPTJLoraBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
+    def __init__(self, config):
+        super().__init__(config)
+
+        convert_to_int8(self.attn)
+        convert_to_int8(self.mlp)
+
+
+class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
+    def __init__(self, config):
+        super().__init__(config)
+        convert_to_int8(self)
+
+
+class GPTJLoraForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
+    def __init__(self, config):
+        super().__init__(config)
+        convert_to_int8(self)
+        if config.add_apapters:
+            add_adapters(self)
+
+
+transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJLoraBlock  # monkey-patch GPT-J
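
gptj.py replaces every Linear and Embedding with the frozen 8-bit modules and, when the config flag is set, attaches small trainable adapters (a down-projection to adapter_dim followed by a zero-initialized up-projection). A fine-tuning sketch, assuming the model was loaded as in the config.json example above and that only the adapter weights should be updated; the optimizer choice is illustrative:

import torch

# model: the instance loaded via AutoModel in the earlier sketch
for name, param in model.named_parameters():
    param.requires_grad = "adapter" in name  # train only the LoRA-style adapter layers

adapter_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(adapter_params, lr=1e-4)
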
lora.py ADDED
@@ -0,0 +1,99 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.cuda.amp import custom_fwd, custom_bwd
+
+from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
+
+
+def quantize_blockwise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
+    assert chunk_size % 4096 == 0
+    code = None
+    chunks = []
+    absmaxes = []
+    flat_tensor = matrix.view(-1)
+    for i in range((matrix.numel() - 1) // chunk_size + 1):
+        input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()
+        quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
+        chunks.append(quantized_chunk)
+        absmaxes.append(absmax_chunk)
+
+    matrix_i8 = torch.cat(chunks).reshape_as(matrix)
+    absmax = torch.cat(absmaxes)
+    return matrix_i8, (absmax, code)
+
+
+class FrozenBNBLinear(nn.Module):
+    def __init__(self, weight, absmax, code, bias=None):
+        assert isinstance(bias, nn.Parameter) or bias is None
+        super().__init__()
+        self.out_features, self.in_features = weight.shape
+        self.register_buffer("weight", weight.requires_grad_(False))
+        self.register_buffer("absmax", absmax.requires_grad_(False))
+        self.register_buffer("code", code.requires_grad_(False))
+        self.adapter = None
+        self.bias = bias
+
+    def forward(self, input):
+        output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias).clone()
+        if self.adapter:
+            output += self.adapter(input)
+        return output
+
+    @classmethod
+    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
+        weights_int8, state = quantize_blockwise_lowmemory(linear.weight)
+        return cls(weights_int8, *state, linear.bias)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
+
+
+class DequantizeAndLinear(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
+                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
+        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
+        ctx.save_for_backward(input, weights_quantized, absmax, code)
+        ctx._has_bias = bias is not None
+        return F.linear(input, weights_deq, bias)
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_output: torch.Tensor):
+        assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
+        input, weights_quantized, absmax, code = ctx.saved_tensors
+        # grad_output: [*batch, out_features]
+        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
+        grad_input = grad_output @ weights_deq
+        grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
+        return grad_input, None, None, None, grad_bias
+
+
+class FrozenBNBEmbedding(nn.Module):
+    def __init__(self, weight, absmax, code):
+        super().__init__()
+        self.num_embeddings, self.embedding_dim = weight.shape
+        self.register_buffer("weight", weight.requires_grad_(False))
+        self.register_buffer("absmax", absmax.requires_grad_(False))
+        self.register_buffer("code", code.requires_grad_(False))
+        self.adapter = None
+
+    def forward(self, input, **kwargs):
+        with torch.no_grad():
+            # note: both quantized weights and input indices are *not* differentiable
+            weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
+            output = F.embedding(input, weight_deq, **kwargs)
+        if self.adapter:
+            output += self.adapter(input)
+        return output
+
+    @classmethod
+    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
+        weights_int8, state = quantize_blockwise_lowmemory(embedding.weight)
+        return cls(weights_int8, *state)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
+
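
lora.py supplies the frozen 8-bit building blocks: weights are blockwise-quantized to uint8 with per-block absmax values and a shared code table, and DequantizeAndLinear dequantizes them on the fly in the forward and backward passes while returning no gradient for the quantized weights. A standalone sketch, assuming a CUDA device and the bitsandbytes version this code targets (where quantize_blockwise returns the absmax/code state directly); the layer size is arbitrary:

import torch
from torch import nn
from lora import FrozenBNBLinear  # this commit's lora.py

dense = nn.Linear(4096, 4096).cuda()         # example size; bitsandbytes typically expects CUDA tensors
frozen = FrozenBNBLinear.from_linear(dense)  # blockwise uint8 weights + absmax/code buffers

x = torch.randn(2, 4096, device="cuda")
y = frozen(x)    # dequantize-and-matmul via DequantizeAndLinear
print(frozen)    # FrozenBNBLinear(4096, 4096)
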
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:f7731cedf324c3aad10e4c9461e6807ad5a6e96f2f849b1d4f0c4556f4b44957
-size 6316424352
+oid sha256:853de3e0341394ced7526eb6286644559127d32c3882c1c4c0bdb978871fe665
+size 6316410080