Spaces:
Runtime error
Runtime error
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper routines for quantization.""" | |
from typing import Any | |
import chex | |
import jax.numpy as jnp | |
from flax import struct | |
# pylint:disable=no-value-for-parameter
@struct.dataclass
class QuantizedValue:
    """State associated with a quantized value.

    Holds either the value as-is (float32), a bfloat16 copy, or an int8/int16
    encoding together with the scale (``bucket_size``) needed to decode it.
    Fields declared with ``struct.field(pytree_node=False)`` are static
    metadata rather than pytree leaves.
    """

    quantized: chex.Array  # Stored (possibly quantized) value.
    diagonal: chex.Array  # Diagonal (if extract_diagonal is set), else [].
    bucket_size: chex.Array  # Decoding scale for integer dtypes, else [].
    quantized_dtype: jnp.dtype = struct.field(
        pytree_node=False
    )  # Dtype for the quantized value.
    extract_diagonal: bool = struct.field(
        pytree_node=False
    )  # Whether the diagonal was split off into `diagonal`.
    shape: Any = struct.field(pytree_node=False)  # Shape of the tensor.

    @classmethod
    def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
        """Quantizes ``fvalue`` and wraps the result in a new instance.

        Args:
          fvalue: Array to quantize, or an empty list, which produces an empty
            placeholder whose array fields and ``shape`` are all ``[]``.
          quantized_dtype: Storage dtype; one of ``jnp.float32``,
            ``jnp.bfloat16``, ``jnp.int8`` or ``jnp.int16``.
          extract_diagonal: If True and ``quantized_dtype`` is an integer
            dtype, the diagonal of the 2-D ``fvalue`` is stored separately in
            the input's float dtype and excluded from quantization.

        Returns:
          An instance of ``cls`` whose ``shape`` is ``list(quantized.shape)``.

        Raises:
          ValueError: Propagated from ``quantize`` for unsupported inputs.
        """
        if isinstance(fvalue, list) and not fvalue:
            return cls([], [], [], quantized_dtype, extract_diagonal, [])
        quantized, diagonal_fvalue, bucket_size = cls.quantize(
            fvalue, quantized_dtype, extract_diagonal
        )
        return cls(
            quantized,
            diagonal_fvalue,
            bucket_size,
            quantized_dtype,
            extract_diagonal,
            list(quantized.shape),
        )

    # Quantization is from Lingvo JAX optimizers.
    # We extend it for int16 quantization of PSD matrices.
    @classmethod
    def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
        """Returns the quantized value, the extracted diagonal and the bucket.

        Args:
          fvalue: Float array with at least one dimension.
          quantized_dtype: ``jnp.float32``, ``jnp.bfloat16``, ``jnp.int8`` or
            ``jnp.int16``.
          extract_diagonal: For integer dtypes, remove the diagonal of the
            2-D ``fvalue`` before quantizing and return it separately.

        Returns:
          Tuple ``(quantized, diagonal, bucket_size)``.
          - ``jnp.float32``: ``fvalue`` is returned unchanged, with ``[]``
            for both ``diagonal`` and ``bucket_size``.
          - ``jnp.bfloat16``: ``fvalue`` is cast to bfloat16, with the same
            empty extras.
          - Integer dtypes: ``bucket_size = max(|fvalue|, axis=0) /
            num_buckets``, giving one scale per position of the trailing
            dimensions (per column for 2-D input). ``quantized`` is
            ``round(fvalue / bucket_size)`` cast to ``quantized_dtype``.

        Raises:
          ValueError: If ``quantized_dtype`` is unsupported, if
            ``extract_diagonal`` is set for a non-2-D input, or if ``fvalue``
            is 0-dimensional.
        """
        if quantized_dtype == jnp.float32:
            return fvalue, [], []
        elif quantized_dtype == jnp.bfloat16:
            return fvalue.astype(jnp.bfloat16), [], []

        float_dtype = fvalue.dtype
        if quantized_dtype == jnp.int8:
            # value -128 is not used.
            num_buckets = jnp.array(127.0, dtype=float_dtype)
        elif quantized_dtype == jnp.int16:
            # value -32768 is not used.
            num_buckets = jnp.array(32767.0, dtype=float_dtype)
        else:
            raise ValueError(f"Quantized dtype {quantized_dtype} not supported.")
        # max value is mapped to num_buckets

        if extract_diagonal and fvalue.ndim != 2:
            raise ValueError(
                f"Input array {fvalue} must be 2D to work with extract_diagonal."
            )

        diagonal_fvalue = []
        if extract_diagonal:
            diagonal_fvalue = jnp.diag(fvalue)
            # Remove the diagonal entries so they do not influence the scale;
            # they are kept separately in their original float dtype.
            fvalue = fvalue - jnp.diag(diagonal_fvalue)

        # TODO(rohananil): Extend this by making use of information about the blocks
        # SM3 style which will be useful for diagonal statistics
        # We first decide the scale.
        if fvalue.ndim < 1:
            raise ValueError(
                f"Input array {fvalue} must have a strictly positive number of "
                "dimensions."
            )

        max_abs = jnp.max(jnp.abs(fvalue), axis=0)
        bucket_size = max_abs / num_buckets
        # Re-insert the reduced axis so the scale broadcasts against fvalue.
        bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
        # To avoid divide by 0.0: an all-zero slice gets divisor 1 and
        # therefore quantizes to zeros.
        bs_nonzero = jnp.where(
            bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded)
        )
        ratio = fvalue / bs_nonzero
        # We use rounding to remove bias.
        quantized = jnp.round(ratio)
        return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size

    def to_float(self):
        """Returns the float value, undoing ``quantize``.

        - An empty placeholder (``quantized == []``) is returned as-is.
        - float32 storage is returned unchanged.
        - bfloat16 storage is cast to float32.
        - Integer storage is multiplied by ``bucket_size`` in the bucket's
          dtype. When ``extract_diagonal`` is set, the separately stored
          diagonal is added back.
        """
        if isinstance(self.quantized, list) and not self.quantized:
            return self.quantized
        if self.quantized_dtype == jnp.float32:
            return self.quantized
        if self.quantized_dtype == jnp.bfloat16:
            return self.quantized.astype(jnp.float32)

        float_dtype = self.bucket_size.dtype
        # Same leading-axis expansion as in `quantize`, so shapes broadcast.
        bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
        val = self.quantized.astype(float_dtype) * bucket_size
        if self.extract_diagonal:
            val += jnp.diag(self.diagonal)
        return val