"""Weight-only INT8/INT4 quantization helpers backed by ``cpm_kernels`` CUDA kernels."""

import base64
import bz2
import ctypes
import logging as _stdlib_logging
from typing import List

import torch

try:
    from transformers.utils import logging
except ImportError:
    # transformers is only used here to obtain a logger; fall back to the
    # standard library so the pure-torch helpers stay importable without it.
    logger = _stdlib_logging.getLogger(__name__)
else:
    logger = logging.get_logger(__name__)

try:
    from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up

    class Kernel:
        """Expose the named functions of a compiled kernel module as attributes."""

        def __init__(self, code: bytes, function_names: List[str]):
            self.code = code
            self._function_names = function_names
            self._cmodule = LazyKernelCModule(self.code)

            # One attribute per kernel entry point, e.g. ``self.weightInt4_fp16``.
            for name in self._function_names:
                setattr(self, name, KernelFunction(self._cmodule, name))

    # bz2-compressed, base64-encoded kernel module. The payload is split into
    # adjacent string literals purely for layout; Python concatenates them
    # into one string at compile time.
    quantization_code = (
        "QlpoOTFBWSZTWapgbn4ALTZ/////////9f/n9+/r/v//3/Tt7cDwfe5sdXXdZNR/9++P4BkfAfIVSVACQUFCEIJQoJUqkgAACJRSqAAAQgEqKqqoKoAAVINADQDTCMhoAGCZBpoNMg0Bk0AGRiGmQBghkNGmgAaGgGhkyAAADJoMhBoAaAaYRkNAAwTINNBpkGgMmgAyMQ0yAMEMho00ADQ0A0MmQAAAZNBkINADQDTCMhoAGCZBpoNMg0Bk0AGRiGmQBghkNGmgAaGgGhkyAAADJoMhBoAaAaYRkNAAwTINNBpkGgMmgAyMQ0yAMEMho00ADQ0A0MmQAAAZNBkCapJBTEyQ0nqeSZioDeqNlG1PU2o9Go9TTQ09QaPUD0CPSAA9Q9QPKA0AGQaANHqBo0GgNGgB6gaApSSCAJiIwmhlNGmjEmCYmTKbSnqP1J6m8gp56ppN6pp5R7U9Q9JMjxTbIk9Mj1BtUMCeoaAB+qNHqAyaNpHtoqfqcgV6L7jHRddxuUdeeq6Vy1s1q0PGdR98uY7Jv2di02u25eLvCuiUfprinG43A5DrtXzq4XdXQ7HSYXSK3F/dxtXgcbZptcVjbJtpppp/Jf7H57vPBXRiuNZXU0OXvTI3V0W5xmyvSd95rhM4s5HF6D4c/uT8etNOq3zqacJzz03vLmXS4Ti2aY77qnmeN2ze315l+oX4t89/x/uvSeN03beGb50V1vfv6lfNn12n6b0Nr8O/F31q1q6ToOydk+ONGjRhho7FPPyPwV4R1FsuRZeCbL18PZnyFyjkGq+Ies8bnvM1MaZeeT30nU/N9Pp+n+f+V+X3gHeXPR8p8ufwnqncyf4K4K2m5qaj1z2mi04YafvWjcbODfNTcysmVvfrTZjhVkHhejPQeOeZ4p5nknpOSZORxb5wdM783XnrL1K6lli2Oqt87d6wxoq9A/sudf+W6sMUfSaLwsnTN1/8P6Uyxj/fPzay3SR2UxR1BeGaFV7DFFecsUrcenP69fnv+pkybyK9J3V5ZpqYxqanlMnjm+OqO56ZxeWPhrS0vUt7DlmN1p3pckN+KtJYG+cJ3nHdBXlG5cTg3Ttraw5zi0vVz7jzst5xk3rRqVcjCLknW9BsR79+RYyYzSuLBWVvm9o+83EtLZ672k6tvu1Xo2MWpHxd6r2g8tdw5n8p/3GOByV4H8N+K7p+W5DDcuDhPi+JjMY6U+y/sP1GOntZY8C6+dj0maa1YYysvzXR0N1jWrucLof8TTZyB45+y6nFjg/z69n0TMdBu1kbTHvWR3mZvN+zc00PBYssS6HoPcW7g7bUzKwy1Y9c5ji4ufZYcurGDGGMzGMZTGVjLLBXdcro2lT5SwSfEmSqfBMpVxLGGTKxkwsr2E0T3yYkH6DPqsYy0zTGmO8ZS4Xv7G1YxjaloxJ3Y4vrmqo5m5qT7Jbm/aIca41org/rW6bJxTDgY7u31qhvMTp"
        "Viiu/uvaK4LEptss0nNYk08D89snSsWLBPxFinvnjR7TGODdV8KyfMup+o2rmRjGLK/ErRxyWPKtRuYxh6DFteVdC1NnqHf1F02DtrErv3bdBsZeZP0VtpNLSxiNYV8BZDqLJ51g9ZXYmWJss2t9i4OJ8hbU3uNobm1cLibk4le1PGjxY2jnS8kcJP1JZLuhncnnjScVhX7Bf0TVXyy+CakeSMk9Bo1X0JyPKbQsrvLVLhWJTwGRXlsVdtcvwGMcN0nrsocD6Y9htSjvMReLCx9i9PaU+wwScpkj02EXpV51fKk2qHW7zRTUydk9k5W1epWXltGktGU6S7DMsmWWJcxkc1W0m6PhRhYsCxZL3C7btuqrcq3WL8A63U2WP1WO1tL4NZJvna1XblwdrlxtQ2d9p1f2WLBlqpi5Wi8U4LVjnY/SYXU3Rq4xpUPZuStUdjrvTi8StNHaOS7VVxVwrt2rmrrDTmS5jRWjE5phjDkrFwum1W9LhYai4OuNSca3JbzDYqWjAuPPedpdWzRlWWKbqV5NcDUjDI5JhoyXKsfnPOvymmOFW5RwmTzrLo78mvRdl+A4LhXCuC6vWrDsaNDNrSw4zDS9V6q3t0my3XDK1bB0VlTaHI5myxblfqn9ph0VwriA4jHMZYZRcito3m83W+sWTG4piuqytrJYLFtq1ZmWrWlixZLGGq4XdDRxOJqjz4vZWR58ZXfRlGoy/KlkNRqlvj8mS2W5ZZVe7NHKN9cpyngzbY2NWsWdWrzPM7/3H1f0/1nJV/9OXU6w0yw8NtNYGbJeJjZbXUZjzMeJuvKYvgNNGMaZuTE3G9h8dPBN28mo8L0a3E7lcK0NKyi98dIXNfoHSPUHGcbJ4U945n0jGn0V2D//d45VNjue9eC1k+Vew1Yw2N3FPZbq2sq3LFyOg8j77ctzkUeAxWLCauB1upch/ib0/JMowwyxNjIY5jUuaOjyVs6KsrnhzGczEtssWSy3WZjGWquxNDRgyMpjDEwWW1qWjKxZczLanB0vvuh6Det7gOpYzGctNLLdWaNMzTLS1WYmGWDbGGlq5lhpuatWtGpWI3GW0ppqqmcV2nfcW9Li52FojtWJXwcFeN7E/Dr48+G+292xp/hf5Z+M/4HRWXtpdq/jX+N5q8ReyK6tHoXyi1DsX0zwW1tlcHkN7H7C7l5a1au+sfwL4N/guN5li+BfvL3t7w+nPeP6Ttr6k/Xfqz91+3P3z2z+uXJWj9Gp9lZtXhpYeRf16RelfDx5tbPBs9IxVdowJ7d8Vgww0X23RWkyspfOYn4RMfiPntI+pVl819O+l+A+E3Mzsf5181q97fOMfUTc3ORjmhxXpjFi0dN4LXdsuOq0undqvur5Jj131ThWPoL9ReB1dK/E415JMC+Sb00nVdHXOqch5ZwTvpe7ucytpPsa61y1zLdMcNKfhmzcdLdGVkaYqaMF9A4tUri6XUxtI6jjNhyzlZvmuFbm0NWBqc60TVdl+df6LtVc9bzhGLmSxLtutNy6gwp1kuFcK61ul4w55LtqtVbE/kNzSvCyR4THEyI3MJyMFYeU1OUMlY7n6A69wOl2Hm2bRt2fRlz210T7Lc4hvrIOFXItTLIwadavpRz19Z+6uVe6/vLwq9B6DnH1ZNNy7l1suc1Xk8Tsq83CrvPCP5zXUaa7xyOLzPx3nfoOqNMqu5mzLXpOB/qHkk3sk00mzELVhJwq5G5vV3q3rV31jRjRpXKYlcjkYfTPVN0e+dTi7Ha6nA4Pf16L+0+S3vxntMeQ94dFs3PbPVaaY0eJ4n12LDmerNzmdVeo3jZ0Wn6TztPAY7nfr8r+XufE+rp/K9z+XyPC3zD0WGP902PXvh9ravQvM7fjz584OBXqDj4V5K9ls8TG3C9byv9tsbzDmcrGcWr23E5XqOven2Hqtzmb/OdBj4Drb3oZvxh81qTZstK3MaYx7xps9tonI3mqnla6q+BuWx2T+O/vv8b3vl9Zvdb+"
        "/XMruD0z4rpdzQxeVbOR3q5J3THBvPA3tzyP8BTnGH3jENMMyxlMsxg0sxi1aasZS633V9BdDpYx0nq+E6HvHS1fMdRp53c077vvVNjnnfdbZ52HQ6GnOx52mlxMrifkJ3P4H333X1n663ud4n8Jzl5TmdDyPmF+c/cd5xcgD8x0v9puXlOx0HfeL4ryD2fBXTXbX0a+jX5E9weUcTicTicTicTicFXA4HA4HA4HhXm+VetOVcxh4p3LV5Myp1UvGPRPcnqD8Q2NjY2NjY2NjhXuUserunePXK67236bxvkMfCcyu2uZ8dy258NqOlwrcHJZecxORjTpelVvbljocqbI53xWZjlcXl+W+Yuze35blS9+MVctlmKc2VdNh8rjTwHcX5rdXynpPYdFmPw30HBte5bNNm+1egy2vC8TqbNNmm1dK8VtX/s+e/i9jdU8BiDhed2um3V2saPM4ydlp2Mb2m+652Ox8KsbrLlnh5KPVGSded05W584ceNxVev7y3XJbo7DptKOplXrsE6WSfEN903xvyZR8B8R8R73dJyhwjtPcJ6NNyK8lWI3PXbm0bOq7/HpmrdUdC4GjvbLjYsWJ7z4J49j4Vo0f4U/bb6zMf9OGlu3zJ/2za/xP8fG4eAOc7E6ho+uWLqGPFxd727L8aq5f/b0n/MtrjX1r+Bafhm3Ce5WmVixZWNPuNGPdvj2x7rLgZcMMmrlO0Buc7fml0arVwuE47mYw0maYynSMORNTUbjhnI1bGN1XKP2zLGpfJb8ZlwXO0ZjGlz4f1l0LhG6MXzHYHcZVt+EaTxlj/cMm5h+cyHrvzGX8SenN09OZOVkYNPLfqk3e4+8WoXmr5NfbH0XD8iT435PY4uVgxo0000f8q/FvjG+vEzGSuxcLV75vrU8aT9K/GclTcTJMFjIc75taDx2DnXPbrRtjv5qnjxMrpco5Z+NacOENnW0bORzOsxwcH0HFcjneNpC/iri4t9cCfzq/erS70xHTX753jYBxrBisj1cU6LIO2bpNo2V9Qak1TF4H2k8rHsNMewfOXyvwr+McsncOT1j/vObbPYPsLvsHx33mn8R7r+ibzH9Ns7jZp8e/eb3zXBwc5+y24vmOZztqvkN6z2j13or4K9dXiV76vXeeGq9k4GyppTHvZuT4Qr5F+fW9Hxji+kux1Sb57OC+57rR7rBvmJMGTHTVonx3M7xsbN5pvPSq1VirJNVqrpbDtna3zYnfXzK/91fTVP8qb02TYjUFZGUskslT6Bf0WK9rHvY+wN1irdG6NVujaNo1G0exl3aXjnQuQdEyVo7vkktptbrL2024zUv+KZGSslZGRhixZYXVPcHSrpLIFi0/qIG5xjzv7xzvEY3SfivEafafzD+Y0MdLGnGuxWmzHQ/0Di2O9VXI2cF3pzXLdGTnzWtD4T5DsW1eCudDg4ty9qd4/sOJeivWODcvq1xbLmfj1qub6ldHQ8LsOCr126tLjw5utuV6ECx0XBrOH2XlfwMaae0556C/nmGMbDTD9+sDmI/BvcvktXvlb1v+adfWu46W4Z6S0cFvfra47b16DcvCcjg2bMq4OXfbQew9sYxjGLSfaXz13Lr21lllqy2twX75Vfu2AyH0YPf24+yjorQ5+8h0S7dR/rH2qVqDdXUk+Sk5IxXPQf9ir234LZzKm8dy41xkt0PgR76HjQ6JL3SqmSOtJ3UHKcRXRKn+Rsf8j7ir7kC6/UfLOA/TUPSgXdAu4V0VV1q4zxsP4xznEWxYnoSP1ivCrxq7wNVTdC0D0lVOeCvEqOi3R3sLnna9RMsY+MamYxjUf2V661P47v2nXtXzs5wetm+bTIdBwh6e+WvvbVc+bZP7Yy1YvBXOqxJ30PIV7fuqx9SMmMMY7LVjG+S6IyPiqp4Eakui7CvM7Kv2EnFUjjAvfFewVwVOSS7Mao9TDxTlVfClhkspcsB2PcfAe8nYOww6SmqwRlZCysNdLaej4FTwo9nwXBZ"
        "aWl5uatmmrfbrZbHBYNmrFtbJ7xtbTG6xYerb7LL2UmT9ur54D+5V6tPU3ozGSyZPEV4S6JYfyT7qaNGzVNDs8Feer6i9gfBrFpbzjfYr0Z4huT5zJk5VWwfzX45uNljyNDG05zZZKdtWi/ecFpYbIu+m0cllZZVlZHQPYDnh469YW+YyzGOs1S5DzmlyGLc62l2rDyDKuQ9jai4VknV0tK2tS0sZjDWVzstxY3VaP4afXf6L6++rhYwx42nFYXeLpTqVp3DHVMeFwJplHv2RyzDe41pTmr9dquVk5WI+7tqN7xMY51NrCwxYxWMMMUODVd94IflreuiuDS4Q8B+01PdWIeR7Z2PSboDmZBzTZbm0vlm+YamDe71cHDeLe3NLQx35tNze/Zd4OtZJ4H5jleDtaO1dLTc/Yf6Tit0uDFV/VZHc77k3njf+BwXQY7l3VpE6HQ1bjTVdzLwTiWkXW61jv/bfz2MftP2mn8hsbNmnfe1dlci/cpeBZe9O9Mb3ad5J4ZMoeVvXnvE3Vf8IfuPGfXMf/K309N6h2v/Rji5HmfdXovWedzucx3N7e+s7GzZvK6n3jxvabNmzc+0/BWl3nefl8G5cr7RzMaaP8jFsvuXMcFe0713O7ZfWn3pPaq4rdKejPByrzPUN6cBH6LZsTzcXrHkrrdfMONXA5hhyybH4Ru4LktC3T4dvvdXM7ThXPdPp7m9yTlH6VdKvbzBd4ZjJll3nAi+2sX31iU4m52d+92rlq6g9cftu+HkO1bx/+LuSKcKEhVMDc/A="
    )

    kernels = Kernel(
        bz2.decompress(base64.b64decode(quantization_code)),
        [
            "weightInt8_int4",
            "weightInt4_fp16",
            "weightInt4_bf16",
        ],
    )
except Exception as exception:
    # Best effort: the 8-bit path is pure torch and keeps working without the
    # compiled kernels, so a load failure is logged instead of raised.
    kernels = None
    logger.warning("Failed to load cpm_kernels:" + str(exception))


def _require_kernels():
    """Return the loaded kernel bundle, or raise if ``cpm_kernels`` failed to load."""
    if kernels is None:
        raise RuntimeError(
            "4-bit quantization needs the cpm_kernels CUDA kernels, which failed to load; "
            "see the warning logged when this module was imported"
        )
    return kernels


def quantize_int8(weight: torch.Tensor, bit_length: int):
    """Quantize a 2-D weight row by row to signed ``bit_length``-bit integers.

    Each row is scaled so that its largest absolute value maps to
    ``2 ** (bit_length - 1) - 1``. The result is stored as ``torch.int8``
    whatever ``bit_length`` is.

    Args:
        weight: 2-D floating point tensor; one scale is computed per row.
        bit_length: Number of bits of the target integer range.

    Returns:
        ``(quantized, weight_scale)``: an ``int8`` tensor shaped like
        ``weight`` and a ``float32`` tensor with one scale per row.
    """
    max_level = (2 ** (bit_length - 1)) - 1
    weight_scale = (weight.abs().max(dim=-1).values / max_level).to(torch.float32)

    # An all-zero row has scale 0. Dividing by it gives 0/0 = NaN, and casting
    # NaN to int8 is undefined. Divide such rows by 1 instead so they quantize
    # to exact zeros; the returned scale is left untouched.
    divisor = torch.where(weight_scale == 0, torch.ones_like(weight_scale), weight_scale)
    quantized = torch.round(weight.to(weight_scale.dtype) / divisor[:, None]).to(torch.int8)
    return quantized, weight_scale


def compress_int4_weight(weight: torch.Tensor):
    """Pack a 4-bit-valued ``int8`` weight into half as many ``int8`` columns.

    Launches the ``weightInt8_int4`` kernel with one block per row.

    Args:
        weight: 2-D ``int8`` tensor with an even number of columns. Its data
            pointer is handed straight to the kernel, so it presumably has to
            be a contiguous CUDA tensor (not checked here).

    Returns:
        ``int8`` tensor of shape ``(rows, cols // 2)`` on the current CUDA device.

    Raises:
        ValueError: If the number of columns is odd.
        RuntimeError: If the ``cpm_kernels`` kernels could not be loaded.
    """
    num_row, num_col = weight.size(0), weight.size(1)
    if num_col % 2 != 0:
        # Integer division below would silently drop the last column.
        raise ValueError(
            f"int4 packing needs an even number of columns, got {num_col}"
        )
    loaded = _require_kernels()

    with torch.cuda.device(weight.device):
        num_chan = num_col // 2

        int8_weight = torch.empty(num_row, num_chan, dtype=torch.int8, device="cuda")
        stream = torch.cuda.current_stream()

        dim_grid = (num_row, 1, 1)
        dim_block = (min(round_up(num_chan, 32), 1024), 1, 1)

        loaded.weightInt8_int4(
            dim_grid,
            dim_block,
            0,
            stream,
            [
                ctypes.c_void_p(weight.data_ptr()),
                ctypes.c_void_p(int8_weight.data_ptr()),
                ctypes.c_int32(num_row),
                ctypes.c_int32(num_chan),
            ],
        )
        return int8_weight


def dequantize_float(weight: torch.Tensor, weight_scale: torch.Tensor, bit_length: int, input: torch.Tensor):
    """Rebuild a floating point weight in the dtype of ``input``.

    Args:
        weight: Quantized ``int8`` weight; packed to half the columns when
            ``bit_length`` is 4.
        weight_scale: One scale per row of ``weight``.
        bit_length: 8 (pure torch) or 4 (CUDA kernel).
        input: Only its dtype is used, to choose the output dtype.

    Returns:
        Tensor of dtype ``input.dtype``: same shape as ``weight`` for 8 bits,
        twice as many columns for 4 bits.

    Raises:
        AssertionError: If ``bit_length`` is neither 8 nor 4.
        RuntimeError: If the 4-bit path is used but the kernels failed to load.
        TypeError: If the 4-bit path is used with a dtype other than
            ``torch.half`` or ``torch.bfloat16``.
    """
    if bit_length == 8:
        return weight.to(input.dtype) * weight_scale.to(input.dtype)[:, None]

    assert bit_length == 4, f"unsupported bit length: {bit_length}"

    loaded = _require_kernels()
    # Only fp16 and bf16 kernels are loaded. Any other dtype used to fall
    # through to the bf16 kernel while the output buffer was allocated in
    # that other dtype.
    if input.dtype == torch.half:
        func = loaded.weightInt4_fp16
    elif input.dtype == torch.bfloat16:
        func = loaded.weightInt4_bf16
    else:
        raise TypeError(
            f"4-bit dequantization supports torch.half and torch.bfloat16, got {input.dtype}"
        )

    with torch.cuda.device(weight.device):
        num_row, num_chan = weight.size(0), weight.size(1)

        # Each packed byte expands to two output values.
        float_weight = torch.empty(num_row, num_chan * 2, dtype=input.dtype, device="cuda")
        stream = torch.cuda.current_stream()

        dim_grid = (num_row, 1, 1)
        dim_block = (min(round_up(num_chan, 32), 1024), 1, 1)

        func(
            dim_grid,
            dim_block,
            0,
            stream,
            [
                ctypes.c_void_p(weight.data_ptr()),
                ctypes.c_void_p(weight_scale.data_ptr()),
                ctypes.c_void_p(float_weight.data_ptr()),
                ctypes.c_int32(num_row),
                ctypes.c_int32(num_chan),
            ],
        )
        return float_weight


class QuantizationLinear(torch.nn.Module):
    """Bias-free linear layer that stores its weight quantized.

    The weight is quantized once at construction and dequantized on every
    forward call, to the dtype of the incoming activations.
    """

    def __init__(self, bit_length: int, weight: torch.Tensor, device="cuda"):
        """Quantize ``weight`` and keep it, with its row scales, on ``device``.

        Args:
            bit_length: 8 or 4.
            weight: 2-D floating point weight of shape ``(out, in)``. For
                4 bits it is packed by a CUDA kernel before being moved to
                ``device``, so it presumably must already be on a CUDA device.
            device: Device the stored parameters are moved to.
        """
        super().__init__()
        self.bit_length = bit_length

        weight, weight_scale = quantize_int8(weight=weight, bit_length=bit_length)
        if bit_length == 4:
            weight = compress_int4_weight(weight)

        self.weight = torch.nn.Parameter(weight.to(device), requires_grad=False)
        self.weight_scale = torch.nn.Parameter(weight_scale.to(device), requires_grad=False)

    def forward(self, input: torch.Tensor):
        """Apply the layer to the last dimension of ``input``.

        Returns:
            Tensor with the leading dimensions of ``input`` and a last
            dimension equal to the number of weight rows.
        """
        input_size = input.size()
        # Flatten all leading dimensions so a single 2-D matmul suffices.
        input = input.contiguous().view(-1, input.size(-1))

        original_weight = dequantize_float(self.weight, self.weight_scale, self.bit_length, input)
        output = torch.matmul(input, original_weight.t())

        return output.view(*(input_size[:-1] + (self.weight.size(0),)))