Enkhai committed on
Commit: 668789a
Parent: 58d858e

Upload GPTJLoraForCausalLM

Files changed (5):
  1. config.json +5 -2
  2. config.py +10 -10
  3. gptj.py +86 -86
  4. lora.py +99 -99
  5. pytorch_model.bin +1 -1
config.json CHANGED
@@ -1,6 +1,7 @@
  {
+   "_name_or_path": "gpt-j-6b-8bit-lora",
    "activation_function": "gelu_new",
-   "add_adapters": true,
+   "add_apapters": true,
    "architectures": [
      "GPTJLoraForCausalLM"
    ],
@@ -10,6 +11,7 @@
      "AutoModelForCausalLM": "gptj.GPTJLoraForCausalLM"
    },
    "bos_token_id": 50256,
+   "eight_bit": true,
    "embd_pdrop": 0.0,
    "eos_token_id": 50256,
    "gradient_checkpointing": false,
@@ -39,7 +41,8 @@
    },
    "tie_word_embeddings": false,
    "tokenizer_class": "GPT2Tokenizer",
-   "transformers_version": "4.20.1",
+   "torch_dtype": "float32",
+   "transformers_version": "4.24.0",
    "use_cache": true,
    "vocab_size": 50400
  }
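
Because config.json maps AutoModelForCausalLM to gptj.GPTJLoraForCausalLM through "auto_map", the checkpoint can be loaded with the stock transformers auto classes once remote code is trusted. A minimal loading sketch; the repository id below is a placeholder, not something recorded in this commit:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "Enkhai/gpt-j-6b-8bit-lora"  # placeholder repo id; substitute the actual Hub repository

    # trust_remote_code=True lets transformers import config.py / gptj.py / lora.py from the repo,
    # which is exactly what the "auto_map" entry in config.json points at.
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(repo_id)  # GPT2Tokenizer per "tokenizer_class"

    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
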
config.py CHANGED
@@ -1,10 +1,10 @@
- from transformers import GPTJConfig
-
-
- class GPTJLoraConfig(GPTJConfig):
-     model_type = "gptj-lora"
-
-     def __init__(self, add_adapters=False, **kwargs):
-         self.add_apapters = add_adapters
-         super().__init__(**kwargs)
-         self.model_type = "gptj-lora"
+ from transformers import GPTJConfig
+
+
+ class GPTJLoraConfig(GPTJConfig):
+     model_type = "gptj-lora"
+
+     def __init__(self, add_adapters=False, **kwargs):
+         self.add_apapters = add_adapters
+         super().__init__(**kwargs)
+         self.model_type = "gptj-lora"
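
For reference, a small sketch of using this config class directly, outside the auto_map flow; note that the constructor keyword is add_adapters while the attribute it stores, and the key that ends up in config.json, is spelled add_apapters:

    from config import GPTJLoraConfig  # assumes the repository files are importable locally

    cfg = GPTJLoraConfig(add_adapters=True)
    print(cfg.model_type)    # "gptj-lora"
    print(cfg.add_apapters)  # True -- note the spelling, mirrored by the key in config.json
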
gptj.py CHANGED
@@ -1,86 +1,86 @@
- import torch
- from torch import nn
- from .lora import FrozenBNBLinear, FrozenBNBEmbedding
- from .config import GPTJLoraConfig
- import transformers
-
-
- def add_adapters(model, adapter_dim=16):
-     assert adapter_dim > 0
-
-     for module in model.modules():
-         if isinstance(module, FrozenBNBLinear):
-             module.adapter = nn.Sequential(
-                 nn.Linear(module.in_features, adapter_dim, bias=False),
-                 nn.Linear(adapter_dim, module.out_features, bias=False),
-             )
-             nn.init.zeros_(module.adapter[1].weight)
-         elif isinstance(module, FrozenBNBEmbedding):
-             module.adapter = nn.Sequential(
-                 nn.Embedding(module.num_embeddings, adapter_dim),
-                 nn.Linear(adapter_dim, module.embedding_dim, bias=False),
-             )
-             nn.init.zeros_(module.adapter[1].weight)
-
-
- def convert_to_int8(model):
-     """Convert linear and embedding modules to 8-bit with optional adapters"""
-     for module in list(model.modules()):
-         for name, child in module.named_children():
-             if isinstance(child, nn.Linear):
-                 setattr(
-                     module,
-                     name,
-                     FrozenBNBLinear(
-                         weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
-                         absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
-                         code=torch.zeros(256),
-                         bias=child.bias,
-                     ),
-                 )
-             elif isinstance(child, nn.Embedding):
-                 setattr(
-                     module,
-                     name,
-                     FrozenBNBEmbedding(
-                         weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
-                         absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
-                         code=torch.zeros(256),
-                     )
-                 )
-
-
- class GPTJLoraBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
-     config_class = GPTJLoraConfig
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.config_class = GPTJLoraConfig
-
-         convert_to_int8(self.attn)
-         convert_to_int8(self.mlp)
-
-
- class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
-     config_class = GPTJLoraConfig
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.config_class = GPTJLoraConfig
-
-         convert_to_int8(self)
-
-
- class GPTJLoraForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
-     config_class = GPTJLoraConfig
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.config_class = GPTJLoraConfig
-
-         convert_to_int8(self)
-         if config.add_apapters:
-             add_adapters(self)
-
-
- transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJLoraBlock  # monkey-patch GPT-J
+ import torch
+ from torch import nn
+ from .lora import FrozenBNBLinear, FrozenBNBEmbedding
+ from .config import GPTJLoraConfig
+ import transformers
+
+
+ def add_adapters(model, adapter_dim=16):
+     assert adapter_dim > 0
+
+     for module in model.modules():
+         if isinstance(module, FrozenBNBLinear):
+             module.adapter = nn.Sequential(
+                 nn.Linear(module.in_features, adapter_dim, bias=False),
+                 nn.Linear(adapter_dim, module.out_features, bias=False),
+             )
+             nn.init.zeros_(module.adapter[1].weight)
+         elif isinstance(module, FrozenBNBEmbedding):
+             module.adapter = nn.Sequential(
+                 nn.Embedding(module.num_embeddings, adapter_dim),
+                 nn.Linear(adapter_dim, module.embedding_dim, bias=False),
+             )
+             nn.init.zeros_(module.adapter[1].weight)
+
+
+ def convert_to_int8(model):
+     """Convert linear and embedding modules to 8-bit with optional adapters"""
+     for module in list(model.modules()):
+         for name, child in module.named_children():
+             if isinstance(child, nn.Linear):
+                 setattr(
+                     module,
+                     name,
+                     FrozenBNBLinear(
+                         weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
+                         absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
+                         code=torch.zeros(256),
+                         bias=child.bias,
+                     ),
+                 )
+             elif isinstance(child, nn.Embedding):
+                 setattr(
+                     module,
+                     name,
+                     FrozenBNBEmbedding(
+                         weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
+                         absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
+                         code=torch.zeros(256),
+                     )
+                 )
+
+
+ class GPTJLoraBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
+     config_class = GPTJLoraConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.config_class = GPTJLoraConfig
+
+         convert_to_int8(self.attn)
+         convert_to_int8(self.mlp)
+
+
+ class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
+     config_class = GPTJLoraConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.config_class = GPTJLoraConfig
+
+         convert_to_int8(self)
+
+
+ class GPTJLoraForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
+     config_class = GPTJLoraConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.config_class = GPTJLoraConfig
+
+         convert_to_int8(self)
+         if config.add_apapters:
+             add_adapters(self)
+
+
+ transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJLoraBlock  # monkey-patch GPT-J
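
convert_to_int8 swaps every nn.Linear / nn.Embedding for its frozen 8-bit counterpart, and add_adapters then attaches a small trainable low-rank adapter to each of them when config.add_apapters is set. A sketch of the fine-tuning setup this enables (freeze the base, train only the adapters); the helper name and optimizer choice are illustrative, not part of this commit:

    import torch
    from torch import nn


    def mark_only_adapters_trainable(model: nn.Module) -> None:
        # The quantized weights live in buffers, but biases and layer norms are still
        # nn.Parameters, so start by freezing every parameter in the model...
        for param in model.parameters():
            param.requires_grad = False
        # ...then re-enable gradients only for the adapter layers that add_adapters()
        # attached to FrozenBNBLinear / FrozenBNBEmbedding modules.
        for module in model.modules():
            adapter = getattr(module, "adapter", None)
            if adapter is not None:
                for param in adapter.parameters():
                    param.requires_grad = True


    # Usage, with `model` loaded as in the earlier example:
    # mark_only_adapters_trainable(model)
    # optimizer = torch.optim.Adam(
    #     (p for p in model.parameters() if p.requires_grad), lr=1e-4
    # )
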
lora.py CHANGED
@@ -1,99 +1,99 @@
- import torch
- from torch import nn
- import torch.nn.functional as F
- from torch.cuda.amp import custom_fwd, custom_bwd
-
- from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
-
-
- def quantize_blockwise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
-     assert chunk_size % 4096 == 0
-     code = None
-     chunks = []
-     absmaxes = []
-     flat_tensor = matrix.view(-1)
-     for i in range((matrix.numel() - 1) // chunk_size + 1):
-         input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()
-         quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
-         chunks.append(quantized_chunk)
-         absmaxes.append(absmax_chunk)
-
-     matrix_i8 = torch.cat(chunks).reshape_as(matrix)
-     absmax = torch.cat(absmaxes)
-     return matrix_i8, (absmax, code)
-
-
- class FrozenBNBLinear(nn.Module):
-     def __init__(self, weight, absmax, code, bias=None):
-         assert isinstance(bias, nn.Parameter) or bias is None
-         super().__init__()
-         self.out_features, self.in_features = weight.shape
-         self.register_buffer("weight", weight.requires_grad_(False))
-         self.register_buffer("absmax", absmax.requires_grad_(False))
-         self.register_buffer("code", code.requires_grad_(False))
-         self.adapter = None
-         self.bias = bias
-
-     def forward(self, input):
-         output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias).clone()
-         if self.adapter:
-             output += self.adapter(input)
-         return output
-
-     @classmethod
-     def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
-         weights_int8, state = quantize_blockwise_lowmemory(linear.weight)
-         return cls(weights_int8, *state, linear.bias)
-
-     def __repr__(self):
-         return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
-
-
- class DequantizeAndLinear(torch.autograd.Function):
-     @staticmethod
-     @custom_fwd
-     def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
-                 absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
-         weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
-         ctx.save_for_backward(input, weights_quantized, absmax, code)
-         ctx._has_bias = bias is not None
-         return F.linear(input, weights_deq, bias)
-
-     @staticmethod
-     @custom_bwd
-     def backward(ctx, grad_output: torch.Tensor):
-         assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
-         input, weights_quantized, absmax, code = ctx.saved_tensors
-         # grad_output: [*batch, out_features]
-         weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
-         grad_input = grad_output @ weights_deq
-         grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
-         return grad_input, None, None, None, grad_bias
-
-
- class FrozenBNBEmbedding(nn.Module):
-     def __init__(self, weight, absmax, code):
-         super().__init__()
-         self.num_embeddings, self.embedding_dim = weight.shape
-         self.register_buffer("weight", weight.requires_grad_(False))
-         self.register_buffer("absmax", absmax.requires_grad_(False))
-         self.register_buffer("code", code.requires_grad_(False))
-         self.adapter = None
-
-     def forward(self, input, **kwargs):
-         with torch.no_grad():
-             # note: both quantized weights and input indices are *not* differentiable
-             weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
-             output = F.embedding(input, weight_deq, **kwargs)
-         if self.adapter:
-             output += self.adapter(input)
-         return output
-
-     @classmethod
-     def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
-         weights_int8, state = quantize_blockwise_lowmemory(embedding.weight)
-         return cls(weights_int8, *state)
-
-     def __repr__(self):
-         return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
-
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from torch.cuda.amp import custom_fwd, custom_bwd
+
+ from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
+
+
+ def quantize_blockwise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
+     assert chunk_size % 4096 == 0
+     code = None
+     chunks = []
+     absmaxes = []
+     flat_tensor = matrix.view(-1)
+     for i in range((matrix.numel() - 1) // chunk_size + 1):
+         input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()
+         quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
+         chunks.append(quantized_chunk)
+         absmaxes.append(absmax_chunk)
+
+     matrix_i8 = torch.cat(chunks).reshape_as(matrix)
+     absmax = torch.cat(absmaxes)
+     return matrix_i8, (absmax, code)
+
+
+ class FrozenBNBLinear(nn.Module):
+     def __init__(self, weight, absmax, code, bias=None):
+         assert isinstance(bias, nn.Parameter) or bias is None
+         super().__init__()
+         self.out_features, self.in_features = weight.shape
+         self.register_buffer("weight", weight.requires_grad_(False))
+         self.register_buffer("absmax", absmax.requires_grad_(False))
+         self.register_buffer("code", code.requires_grad_(False))
+         self.adapter = None
+         self.bias = bias
+
+     def forward(self, input):
+         output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias).clone()
+         if self.adapter:
+             output += self.adapter(input)
+         return output
+
+     @classmethod
+     def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
+         weights_int8, state = quantize_blockwise_lowmemory(linear.weight)
+         return cls(weights_int8, *state, linear.bias)
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
+
+
+ class DequantizeAndLinear(torch.autograd.Function):
+     @staticmethod
+     @custom_fwd
+     def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
+                 absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
+         weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
+         ctx.save_for_backward(input, weights_quantized, absmax, code)
+         ctx._has_bias = bias is not None
+         return F.linear(input, weights_deq, bias)
+
+     @staticmethod
+     @custom_bwd
+     def backward(ctx, grad_output: torch.Tensor):
+         assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
+         input, weights_quantized, absmax, code = ctx.saved_tensors
+         # grad_output: [*batch, out_features]
+         weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
+         grad_input = grad_output @ weights_deq
+         grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
+         return grad_input, None, None, None, grad_bias
+
+
+ class FrozenBNBEmbedding(nn.Module):
+     def __init__(self, weight, absmax, code):
+         super().__init__()
+         self.num_embeddings, self.embedding_dim = weight.shape
+         self.register_buffer("weight", weight.requires_grad_(False))
+         self.register_buffer("absmax", absmax.requires_grad_(False))
+         self.register_buffer("code", code.requires_grad_(False))
+         self.adapter = None
+
+     def forward(self, input, **kwargs):
+         with torch.no_grad():
+             # note: both quantized weights and input indices are *not* differentiable
+             weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
+             output = F.embedding(input, weight_deq, **kwargs)
+         if self.adapter:
+             output += self.adapter(input)
+         return output
+
+     @classmethod
+     def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
+         weights_int8, state = quantize_blockwise_lowmemory(embedding.weight)
+         return cls(weights_int8, *state)
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
+
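
lora.py also provides the from_linear / from_embedding constructors used when quantizing a full-precision checkpoint. A small round-trip illustration; it assumes bitsandbytes is installed and, depending on the bitsandbytes version, may need a CUDA device for the blockwise quantization kernels:

    import torch
    from torch import nn

    from lora import FrozenBNBLinear  # local import from this repository

    torch.manual_seed(0)
    dense = nn.Linear(4096, 4096)

    # Blockwise-quantize the dense weight to uint8 plus per-block absmax scales.
    frozen = FrozenBNBLinear.from_linear(dense)
    print(frozen)  # FrozenBNBLinear(4096, 4096)

    x = torch.randn(2, 4096)
    with torch.no_grad():
        y_ref = dense(x)
        y_q = frozen(x)  # dequantizes on the fly inside DequantizeAndLinear

    # Blockwise 8-bit quantization is lossy, so a small approximation error is expected.
    print((y_ref - y_q).abs().max())
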
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:153cb853074d3fb66c18f93b78297f0d88e252eb1f2a2e5779dff97453a63124
+ oid sha256:10793d174ead92956a981a490ea62ebd2d2109ed944f8fb2fa2815e987988449
  size 6316410080