mrm8488 committed on
Commit
181aab7
1 Parent(s): ade9726

Delete utils.py

Files changed (1)
  utils.py +0 -127
utils.py DELETED
@@ -1,127 +0,0 @@
-import torch
-import torch.nn.functional as F
-from torch import nn
-from torch.cuda.amp import custom_fwd, custom_bwd
-
-from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
-
-
-
-class FrozenBNBLinear(nn.Module):
-    def __init__(self, weight, absmax, code, bias=None):
-        assert isinstance(bias, nn.Parameter) or bias is None
-        super().__init__()
-        self.out_features, self.in_features = weight.shape
-        self.register_buffer("weight", weight.requires_grad_(False))
-        self.register_buffer("absmax", absmax.requires_grad_(False))
-        self.register_buffer("code", code.requires_grad_(False))
-        self.adapter = None
-        self.bias = bias
-
-    def forward(self, input):
-        output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
-        if self.adapter:
-            output += self.adapter(input)
-        return output
-
-    @classmethod
-    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
-        weights_int8, state = quantize_blockise_lowmemory(linear.weight)
-        return cls(weights_int8, *state, linear.bias)
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
-
-
-class DequantizeAndLinear(torch.autograd.Function):
-    @staticmethod
-    @custom_fwd
-    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
-                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
-        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
-        ctx.save_for_backward(input, weights_quantized, absmax, code)
-        ctx._has_bias = bias is not None
-        return F.linear(input, weights_deq, bias)
-
-    @staticmethod
-    @custom_bwd
-    def backward(ctx, grad_output: torch.Tensor):
-        assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
-        input, weights_quantized, absmax, code = ctx.saved_tensors
-        # grad_output: [*batch, out_features]
-        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
-        grad_input = grad_output @ weights_deq
-        grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
-        return grad_input, None, None, None, grad_bias
-
-
-class FrozenBNBEmbedding(nn.Module):
-    def __init__(self, weight, absmax, code):
-        super().__init__()
-        self.num_embeddings, self.embedding_dim = weight.shape
-        self.register_buffer("weight", weight.requires_grad_(False))
-        self.register_buffer("absmax", absmax.requires_grad_(False))
-        self.register_buffer("code", code.requires_grad_(False))
-        self.adapter = None
-
-    def forward(self, input, **kwargs):
-        with torch.no_grad():
-            # note: both quantized weights and input indices are *not* differentiable
-            weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
-            output = F.embedding(input, weight_deq, **kwargs)
-        if self.adapter:
-            output += self.adapter(input)
-        return output
-
-    @classmethod
-    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
-        weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
-        return cls(weights_int8, *state)
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
-
-
-def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
-    assert chunk_size % 4096 == 0
-    code = None
-    chunks = []
-    absmaxes = []
-    flat_tensor = matrix.view(-1)
-    for i in range((matrix.numel() - 1) // chunk_size + 1):
-        input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()
-        quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
-        chunks.append(quantized_chunk)
-        absmaxes.append(absmax_chunk)
-
-    matrix_i8 = torch.cat(chunks).reshape_as(matrix)
-    absmax = torch.cat(absmaxes)
-    return matrix_i8, (absmax, code)
-
-
-def convert_to_int8(model):
-    """Convert linear and embedding modules to 8-bit with optional adapters"""
-    for module in list(model.modules()):
-        for name, child in module.named_children():
-            if isinstance(child, nn.Linear):
-                print(name, child)
-                setattr(
-                    module,
-                    name,
-                    FrozenBNBLinear(
-                        weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
-                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
-                        code=torch.zeros(256),
-                        bias=child.bias,
-                    ),
-                )
-            elif isinstance(child, nn.Embedding):
-                setattr(
-                    module,
-                    name,
-                    FrozenBNBEmbedding(
-                        weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
-                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
-                        code=torch.zeros(256),
-                    )
-                )
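
For context on what this deleted file did: convert_to_int8 swaps every nn.Linear and nn.Embedding in a model for the frozen 8-bit placeholders above (zero-filled weight/absmax/code buffers), which are then meant to be filled from a pre-quantized checkpoint, while the optional .adapter attribute leaves room for small trainable layers on top of the frozen weights. The sketch below shows one plausible way this could be wired up; the checkpoint name, the adapter rank, and the convert-then-load ordering are illustrative assumptions, not something recorded in this commit.

    # Hypothetical usage sketch (assumed: GPT-J checkpoint id, adapter rank 4,
    # converting after from_pretrained and then loading an 8-bit state_dict).
    import torch
    from torch import nn
    import transformers

    model = transformers.AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
    convert_to_int8(model)  # replace nn.Linear / nn.Embedding with frozen 8-bit stubs
    # model.load_state_dict(...)  # fill the weight/absmax/code buffers from an 8-bit checkpoint

    # Attach small trainable adapters (LoRA-style) next to each frozen linear layer.
    for module in model.modules():
        if isinstance(module, FrozenBNBLinear):
            module.adapter = nn.Sequential(
                nn.Linear(module.in_features, 4, bias=False),
                nn.Linear(4, module.out_features, bias=False),
            )
            nn.init.zeros_(module.adapter[1].weight)  # adapters start as a no-op

Because the quantized weight, absmax, and code tensors are registered as non-trainable buffers, only the adapter parameters would receive gradients, so fine-tuning touches a small fraction of the model while the 8-bit base stays frozen.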