|
<!DOCTYPE html> |
|
<html xmlns="http://www.w3.org/1999/xhtml" lang="" xml:lang=""> |
|
<head> |
|
<meta charset="utf-8" /> |
|
<meta name="generator" content="pandoc" /> |
|
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes" /> |
|
<title>QuIP#</title> |
|
<style> |
|
html { |
|
font-family: cambria; |
|
font-size: 16pt; |
|
color: #1a1a1a; |
|
background-color: #fdfdfd; |
|
} |
|
body { |
|
margin: 0 auto; |
|
max-width: 36em; |
|
padding-left: 50px; |
|
padding-right: 50px; |
|
padding-top: 50px; |
|
padding-bottom: 50px; |
|
hyphens: auto; |
|
overflow-wrap: break-word; |
|
text-rendering: optimizeLegibility; |
|
font-kerning: normal; |
|
} |
|
@media (max-width: 600px) { |
|
body { |
|
font-size: 0.9em; |
|
padding: 12px; |
|
} |
|
h1 { |
|
font-size: 1.8em; |
|
} |
|
} |
|
@media print { |
|
html { |
|
background-color: white; |
|
} |
|
body { |
|
background-color: transparent; |
|
color: black; |
|
font-size: 12pt; |
|
} |
|
p, h2, h3 { |
|
orphans: 3; |
|
widows: 3; |
|
} |
|
h2, h3, h4 { |
|
page-break-after: avoid; |
|
} |
|
} |
|
p { |
|
margin: 1em 0; |
|
} |
|
a { |
|
color: #1a1a1a; |
|
} |
|
a:visited { |
|
color: #1a1a1a; |
|
} |
|
img { |
|
max-width: 100%; |
|
} |
|
svg { |
|
height: auto; |
|
max-width: 100%; |
|
} |
|
h1, h2, h3, h4, h5, h6 { |
|
margin-top: 1.4em; |
|
} |
|
h5, h6 { |
|
font-size: 1em; |
|
font-style: italic; |
|
} |
|
h6 { |
|
font-weight: normal; |
|
} |
|
ol, ul { |
|
padding-left: 1.7em; |
|
margin-top: 1em; |
|
} |
|
li > ol, li > ul { |
|
margin-top: 0; |
|
} |
|
blockquote { |
|
margin: 1em 0 1em 1.7em; |
|
padding-left: 1em; |
|
border-left: 2px solid #e6e6e6; |
|
color: #606060; |
|
} |
|
code { |
|
font-family: Menlo, Monaco, Consolas, 'Lucida Console', monospace; |
|
font-size: 85%; |
|
margin: 0; |
|
hyphens: manual; |
|
} |
|
pre { |
|
margin: 1em 0; |
|
overflow: auto; |
|
} |
|
pre code { |
|
padding: 0; |
|
overflow: visible; |
|
overflow-wrap: normal; |
|
} |
|
.sourceCode { |
|
background-color: transparent; |
|
overflow: visible; |
|
} |
|
hr { |
|
background-color: #1a1a1a; |
|
border: none; |
|
height: 1px; |
|
margin: 1em 0; |
|
} |
|
table { |
|
margin: 1em 0; |
|
border-collapse: collapse; |
|
width: 100%; |
|
overflow-x: auto; |
|
display: block; |
|
font-variant-numeric: lining-nums tabular-nums; |
|
} |
|
table caption { |
|
margin-bottom: 0.75em; |
|
} |
|
tbody { |
|
margin-top: 0.5em; |
|
border-top: 1px solid #1a1a1a; |
|
border-bottom: 1px solid #1a1a1a; |
|
} |
|
th { |
|
border-top: 1px solid #1a1a1a; |
|
padding: 0.25em 0.5em 0.25em 0.5em; |
|
} |
|
td { |
|
padding: 0.125em 0.5em 0.25em 0.5em; |
|
} |
|
header { |
|
margin-bottom: 4em; |
|
text-align: center; |
|
} |
|
#TOC li { |
|
list-style: none; |
|
} |
|
#TOC ul { |
|
padding-left: 1.3em; |
|
} |
|
#TOC > ul { |
|
padding-left: 0; |
|
} |
|
#TOC a:not(:hover) { |
|
text-decoration: none; |
|
} |
|
code{white-space: pre-wrap;} |
|
span.smallcaps{font-variant: small-caps;} |
|
div.columns{display: flex; gap: min(4vw, 1.5em);} |
|
div.column{flex: auto; overflow-x: auto;} |
|
div.hanging-indent{margin-left: 1.5em; text-indent: -1.5em;} |
|
|
|
|
|
ul.task-list[class]{list-style: none;} |
|
ul.task-list li input[type="checkbox"] { |
|
font-size: inherit; |
|
width: 0.8em; |
|
margin: 0 0.8em 0.2em -1.6em; |
|
vertical-align: middle; |
|
} |
|
</style> |
|
<script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> |
|
<script |
|
src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-chtml-full.js" |
|
type="text/javascript"></script> |
|
|
|
|
|
|
|
</head> |
|
<body> |
|
<style> |
|
body { max-width: 800px !important; text-align: justify; } |
|
tbody { |
|
border-top: none; |
|
border-bottom: none; |
|
} |
|
header { height:0px;} |
|
tr:nth-child(2n) { |
|
background-color: #EEEEEE; |
|
} |
|
th { |
|
background-color: #EEEEEE; |
|
} |
|
</style> |
|
<h2 id="quip-quip-with-lattice-codebooks">QuIP#: <a |
|
href="https://github.com/jerry-chee/QuIP">QuIP</a> with Lattice |
|
Codebooks</h2> |
|
<p><a href="https://tsengalb99.github.io">Albert Tseng*</a>, <a |
|
href="https://jerry-chee.github.io/">Jerry Chee*</a>, <a |
|
href="https://nalzok.github.io/">Qingyao Sun</a>, <a |
|
href="https://www.cs.cornell.edu/~kuleshov/">Volodymyr Kuleshov</a>, and |
|
<a href="https://www.cs.cornell.edu/~cdesa/">Chris De Sa</a></p> |
|
<hr /> |
|
<p><img src="img/overview.svg" /></p> |
|
<p>Large language models (LLMs) exhibit amazing performance on a wide |
|
variety of tasks such as text modeling and code generation. However, |
|
they are also very large. For example, Llama 2 70B has 70 billion
parameters, which require 140GB of memory to store in half precision. This
|
presents many challenges, such as needing multiple GPUs just to serve a |
|
single LLM. To address these issues, researchers have developed |
|
compression methods that reduce the size of models without destroying |
|
performance.</p> |
|
<p>One class of methods, post-training quantization, compresses trained |
|
model weights into lower precision formats to reduce memory |
|
requirements. For example, quantizing a model from 16 bit to 2 bit |
|
precision would reduce the size of the model by 8x, meaning that even |
|
Llama 2 70B would fit on a single 24GB GPU. In this work, we introduce |
|
<strong>QuIP#</strong>, which combines lattice codebooks with |
|
incoherence processing to create state-of-the-art 2 bit quantized |
|
models. These two methods allow QuIP# to significantly close the gap |
|
between 2 bit quantized LLMs and unquantized 16 bit models.</p> |
|
<div style="margin-left: auto; |
|
margin-right: auto; |
|
width: 90%;"> |
|
<table style="width:100%;"> |
|
<caption>Quantization results on Llama 2 70B. QuIP# achieves near-native |
|
performance at 2 bits, outperforming all other presented |
|
baselines.</caption> |
|
<colgroup> |
|
<col style="width: 16%" /> |
|
<col style="width: 16%" /> |
|
<col style="width: 16%" /> |
|
<col style="width: 16%" /> |
|
<col style="width: 16%" /> |
|
<col style="width: 16%" /> |
|
</colgroup> |
|
<thead> |
|
<tr class="header"> |
|
<th style="text-align: center;">Method</th> |
|
<th style="text-align: center;">Precision</th> |
|
<th style="text-align: center;">Wiki <span |
|
class="math inline">\(\downarrow\)</span></th> |
|
<th style="text-align: center;">C4 <span |
|
class="math inline">\(\downarrow\)</span></th> |
|
<th style="text-align: center;">ArcE <span |
|
class="math inline">\(\uparrow\)</span></th> |
|
<th style="text-align: center;">PiQA <span |
|
class="math inline">\(\uparrow\)</span></th> |
|
</tr> |
|
</thead> |
|
<tbody> |
|
<tr class="odd"> |
|
<td style="text-align: center;">Native</td> |
|
<td style="text-align: center;">16 bit</td> |
|
<td style="text-align: center;">3.120</td> |
|
<td style="text-align: center;">5.533</td> |
|
<td style="text-align: center;">0.597</td> |
|
<td style="text-align: center;">0.809</td> |
|
</tr> |
|
<tr class="even"> |
|
<td style="text-align: center;">OPTQ</td> |
|
<td style="text-align: center;">3 bit</td> |
|
<td style="text-align: center;">4.577</td> |
|
<td style="text-align: center;">6.838</td> |
|
<td style="text-align: center;">0.544</td> |
|
<td style="text-align: center;"><strong>0.786</strong></td> |
|
</tr> |
|
<tr class="odd"> |
|
<td style="text-align: center;">OPTQ</td> |
|
<td style="text-align: center;">2 bit</td> |
|
<td style="text-align: center;">109.820</td> |
|
<td style="text-align: center;">62.692</td> |
|
<td style="text-align: center;">0.253</td> |
|
<td style="text-align: center;">0.505</td> |
|
</tr> |
|
<tr class="even"> |
|
<td style="text-align: center;">QuIP</td> |
|
<td style="text-align: center;">2 bit</td> |
|
<td style="text-align: center;">5.574</td> |
|
<td style="text-align: center;">8.268</td> |
|
<td style="text-align: center;">0.544</td> |
|
<td style="text-align: center;">0.751</td> |
|
</tr> |
|
<tr class="odd"> |
|
<td style="text-align: center;"><strong>QuIP#</strong></td> |
|
<td style="text-align: center;"><strong>2 bit</strong></td> |
|
<td style="text-align: center;"><strong>4.159</strong></td> |
|
<td style="text-align: center;"><strong>6.529</strong></td> |
|
<td style="text-align: center;"><strong>0.595</strong></td> |
|
<td style="text-align: center;">0.786</td> |
|
</tr> |
|
</tbody> |
|
</table> |
|
</div> |
|
<div |
|
style="color:steelblue; margin-left: -14%; margin-right: auto; width: 115%"> |
|
<table> |
|
<colgroup> |
|
<col style="width: 3%" /> |
|
<col style="width: 96%" /> |
|
</colgroup> |
|
<tbody> |
|
<tr class="odd"> |
|
<td style="text-align: right;"><span |
|
style="font-size:72pt">☞</span></td> |
|
<td><strong>Our method, QuIP#, creates 2 bit LLMs that achieve |
|
near-native performance, a previously unseen result. We provide a <a |
|
href="https://huggingface.co/relaxml">full suite of 2 bit Llama 1 and 2 |
|
models quantized using QuIP#</a>, as well as a full codebase that allows |
|
users to quantize and deploy their own models. We also provide CUDA |
|
kernels that accelerate inference for QuIP# models. Our code is |
|
available <a |
|
href="https://github.com/Cornell-RelaxML/quip-sharp">here</a>.</strong></td> |
|
</tr> |
|
</tbody> |
|
</table> |
|
</div> |
|
<h3 id="method-overview">Method Overview</h3> |
|
<p>QuIP# relies on two main components: <em>incoherence processing</em> |
|
and <em>lattice codebooks</em>. Incoherence processing in the context of |
|
model quantization was introduced in QuIP. While QuIP used a Kronecker |
|
product to perform incoherence processing, we introduce a Hadamard
transform-based approach that is more amenable to fast GPU
|
acceleration.</p> |
|
<p>Incoherence-processed weights are approximately Gaussian-distributed, |
|
which means that they are suitable for quantizing with symmetric and |
|
“round” codebooks. We introduce a new lattice codebook based on the |
|
<span class="math inline">\(E_8\)</span> lattice, which achieves the |
|
optimal unit-ball packing density in 8 dimensions. Our codebooks are
|
specifically designed to be hardware-friendly by exploiting symmetries |
|
in these lattices.</p> |
|
<h3 id="quantization-background">Quantization Background</h3> |
|
<p>In QuIP#, we follow existing state-of-the-art post-training
|
quantization methods and round weights to minimize the per-layer |
|
“adaptive rounding” proxy objective</p> |
|
<p><span class="math display">\[ |
|
\ell(\hat W) |
|
= E_x \left[ \| (\hat W - W)x \|^2 \right] |
|
= \operatorname{tr}\left( |
|
(\hat W - W) H (\hat W - W)^T |
|
\right). |
|
\]</span></p> |
|
<p>Here, <span class="math inline">\(W \in \mathbb{R}^{m \times |
|
n}\)</span> is the original weight matrix in a given layer, <span |
|
class="math inline">\(\hat W = \mathbb{R}^{m \times n}\)</span> are the |
|
quantized weights, <span class="math inline">\(x \in |
|
\mathbb{R}^n\)</span> is an input vector drawn uniformly at random from |
|
a calibration set, and <span class="math inline">\(H\)</span> is the |
|
second moment matrix of these vectors, interpreted as a proxy Hessian. |
|
This intra-layer formulation makes quantization tracatable for large |
|
language models. The original QuIP paper forumlated a class of adaptive |
|
rounding methods that used linear feedback to minimize <span |
|
class="math inline">\(\ell\)</span>. Within this class, the LDLQ |
|
rounding algorithm was shown to be optimal; we use LDLQ in QuIP# as |
|
well.</p> |
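<p>To make this concrete, below is a minimal NumPy sketch of LDLQ-style
rounding with linear feedback. It assumes a factorization <span
class="math inline">\(H = \bar U D \bar U^T\)</span> with <span
class="math inline">\(\bar U\)</span> unit upper triangular, and uses
elementwise nearest rounding as a stand-in quantizer; see the QuIP paper
and our code for the exact formulation.</p>
<pre><code class="language-python">import numpy as np

def ldlq(W, H, quant=np.round):
    """Illustrative LDLQ-style adaptive rounding (see QuIP for the exact method).

    W: (m, n) weights; H: (n, n) positive-definite proxy Hessian;
    quant: elementwise quantizer standing in for a real codebook.
    Factor H = U_bar D U_bar^T with U_bar unit upper triangular, then round
    columns left to right, feeding earlier rounding errors into later columns.
    """
    n = H.shape[0]
    # "reverse Cholesky": H = M M^T with M upper triangular
    L = np.linalg.cholesky(H[::-1, ::-1])
    M = L[::-1, ::-1]
    U = M / np.diag(M)[None, :] - np.eye(n)   # strictly upper triangular feedback
    W_hat = np.zeros_like(W)
    for k in range(n):
        # linear feedback from the error of columns quantized so far
        feedback = (W[:, :k] - W_hat[:, :k]) @ U[:k, k]
        W_hat[:, k] = quant(W[:, k] + feedback)
    return W_hat</code></pre>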
|
<h3 id="incoherence-processing">Incoherence Processing</h3> |
|
<p>The main insight of QuIP is that incoherent weight and Hessian
matrices result in improved quantization performance. Informally,
incoherence means that the weights are even in magnitude and that the
important rounding directions (from the Hessian) are not concentrated in
any one coordinate; such matrices are significantly easier to quantize
without catastrophic error. In some
|
sense, incoherence processing can be viewed as a form of outlier |
|
suppression across weight and activation spaces.</p> |
|
<div style="background-color: #EEEEEE;"> |
|
<p><strong>Definition.</strong> <em>We say a symmetric Hessian matrix |
|
<span class="math inline">\(H \in \mathbb{R}^{n \times n}\)</span> is |
|
<span class="math inline">\(\mu\)</span>-incoherent if it has an |
|
eigendecomposition <span class="math inline">\(H = Q \Lambda |
|
Q^T\)</span> such that for all <span class="math inline">\(i\)</span> |
|
and <span class="math inline">\(j\)</span>, <span |
|
class="math inline">\(|Q_{ij}| = |e_i^T Q e_j| \leq \mu / |
|
\sqrt{n}\)</span>. By extension, we say a weight matrix <span |
|
class="math inline">\(W \in \mathbb{R}^{m \times n}\)</span> is <span |
|
class="math inline">\(\mu\)</span>-incoherent if for all <span |
|
class="math inline">\(i\)</span> and <span |
|
class="math inline">\(j\)</span>, <span class="math inline">\(|W_{ij}| = |
|
|e_i^T W e_j| \leq \mu \|W\|_F / \sqrt{mn}\)</span>.</em></p> |
|
</div> |
|
<p>Incoherence is an important property for quantizing models. In QuIP, |
|
the incoherence condition on <span class="math inline">\(H\)</span> is |
|
required to show that LDLQ achieves a superior proxy loss to nearest and |
|
stochastic rounding through a spectral bound on <span |
|
class="math inline">\(H\)</span>. Therefore, it is important to be able |
|
to incoherence-process weight and Hessian matrices efficiently so that
|
incoherence-processed models can be tractably deployed.</p> |
|
<p>One way to do this is by conjugating <span |
|
class="math inline">\(W\)</span> and <span |
|
class="math inline">\(H\)</span> by random orthogonal matrices. Let |
|
<span class="math inline">\(U \in \mathbb{R}^{m \times m}\)</span>, and |
|
<span class="math inline">\(V \in \mathbb{R}^{n \times n}\)</span> be |
|
two random orthogonal matrices. If we assign <span |
|
class="math inline">\(\tilde H \gets V H V^T\)</span> and <span |
|
class="math inline">\(\tilde W \gets U W V^T\)</span>, <span |
|
class="math inline">\(\tilde H\)</span> and <span |
|
class="math inline">\(\tilde W\)</span> become incoherence processed |
|
with high probability (see QuIP for proof). One can verify that this |
|
transformation preserves the proxy objective as <span |
|
class="math display">\[\operatorname{tr}(\tilde W \tilde H \tilde W^T) = |
|
\operatorname{tr}((U W V^T) (V H V^T) (V W^T U^T)) = |
|
\operatorname{tr}(WHW^T).\]</span></p> |
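<p>As a quick numerical sanity check (illustrative code, not our
pipeline), one can draw random orthogonal matrices via QR decompositions
of Gaussian matrices and verify that the conjugation leaves the trace
unchanged:</p>
<pre><code class="language-python">import numpy as np

rng = np.random.default_rng(0)
m, n = 64, 128

def random_orthogonal(d):
    # QR of a Gaussian matrix yields a random orthogonal matrix
    Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
    return Q

W = rng.standard_normal((m, n))
X = rng.standard_normal((4 * n, n))
H = X.T @ X / X.shape[0]                  # proxy Hessian: second moment of inputs

U, V = random_orthogonal(m), random_orthogonal(n)
W_t, H_t = U @ W @ V.T, V @ H @ V.T       # conjugated ("incoherence processed") W, H

# the proxy objective tr(W H W^T) is unchanged by the conjugation
print(np.allclose(np.trace(W @ H @ W.T), np.trace(W_t @ H_t @ W_t.T)))</code></pre>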
|
<h4 id="randomized-hadamard-transformation-rht">Randomized Hadamard |
|
Transformation (RHT)</h4> |
|
<p>To construct <span class="math inline">\(U\)</span> and <span |
|
class="math inline">\(V\)</span> from above, we use the RHT, which is |
|
amenable to fast GPU implementation; in fact, a fast Walsh-Hadamard
transform is among the official CUDA sample kernels. The RHT performs the
multiplication <span
class="math inline">\(x \in \mathbb{R}^n \to \mathbb{H}Sx\)</span>,
where <span class="math inline">\(\mathbb{H}\)</span> is an <span
class="math inline">\(n \times n\)</span> Hadamard matrix (scaled by a
normalization factor) and <span class="math inline">\(S\)</span> is an
<span class="math inline">\(n\)</span>-dimensional random sign vector.
|
The RHT concentrates the entries of <span |
|
class="math inline">\(x\)</span> and thus results in incoherent matrices |
|
through an <a |
|
href="http://www.cs.cmu.edu/afs/cs/user/dwoodruf/www/teaching/15859-fall20/lecture_2.1.pdf">application |
|
of the Azuma-Hoeffding inequality</a>. Note that the Hadamard transform |
|
can be computed more efficiently than a matrix multiplication via the |
|
fast Walsh-Hadamard transform, which we employ directly for powers of 2. |
|
To handle non-power-of-two values of <span
class="math inline">\(n\)</span>, we perform the following
algorithm (sketched in code after the list):</p>
|
<ol type="1"> |
|
<li>Let <span class="math inline">\(p\)</span> be the remaining
non-power-of-two factor of <span class="math inline">\(n\)</span> and
reshape <span class="math inline">\(Sx\)</span> into an
<span class="math inline">\(n/p \times p\)</span> matrix.</li>
<li>Perform the fast Walsh-Hadamard transform on <span
class="math inline">\(Sx\)</span> along the dimension of size <span
class="math inline">\(n/p\)</span>.</li>
|
<li>Let <span class="math inline">\(\mathbb{H}'\)</span> be a <span |
|
class="math inline">\(p \times p\)</span> scaled Hadamard matrix. Apply |
|
this Hadamard transform to <span class="math inline">\(Sx\)</span> on |
|
the right, and reshape back.</li> |
|
</ol> |
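<p>Below is a NumPy sketch of this procedure. The <span
class="math inline">\(p \times p\)</span> scaled Hadamard matrix is
assumed to be supplied (e.g., from a table of known Hadamard
constructions), and the exact reshaping conventions of our CUDA kernels
may differ:</p>
<pre><code class="language-python">import numpy as np

def fwht(A):
    """Unnormalized fast Walsh-Hadamard transform along axis 0 (power-of-2 length)."""
    A = A.copy()
    h = 1
    while h &lt; A.shape[0]:
        for i in range(0, A.shape[0], 2 * h):
            a, b = A[i:i + h].copy(), A[i + h:i + 2 * h].copy()
            A[i:i + h], A[i + h:i + 2 * h] = a + b, a - b
        h *= 2
    return A

def rht(x, sign, Hp=None):
    """Sketch of the RHT x -> H S x. sign is a random +/-1 vector.

    If len(x) is not a power of two, Hp must be a p x p Hadamard matrix
    scaled by 1/sqrt(p), where len(x) = (power of two) * p.
    """
    y = sign * x
    if Hp is None:
        return fwht(y) / np.sqrt(len(y))       # power-of-two case
    p = Hp.shape[0]
    Y = y.reshape(-1, p)                       # step 1: (n/p, p) matrix
    Y = fwht(Y) / np.sqrt(Y.shape[0])          # step 2: transform the n/p axis
    return (Y @ Hp).reshape(-1)                # step 3: p x p Hadamard on the right</code></pre>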
|
<p>The only consequence of performing the RHT is needing to store two sign
|
vectors per layer: <span class="math inline">\(S_U\)</span> and <span |
|
class="math inline">\(S_V\)</span>. Since large language models have |
|
large weight and Hessian matrices, this increases the number of
bits per weight by less than 0.01 in practice, a negligible
amount.</p>
|
<h3 id="lattice-codebooks">Lattice Codebooks</h3> |
|
<p>Incoherence-processed weights are approximately Gaussian-distributed,
meaning that they are symmetric and “round.” To take advantage of this
“roundness,” we can use <span class="math inline">\(n\)</span>-dimensional
codebooks that quantize <span
class="math inline">\(n\)</span> weights at once. Specifically, to
quantize <span class="math inline">\(x \in \mathbb{R}^n\)</span> to an
<span class="math inline">\(n\)</span>-dimensional codebook <span
class="math inline">\(C \in \mathbb{R}^{m \times n}\)</span>, we round
<span class="math inline">\(x\)</span> to its nearest entry in <span
class="math inline">\(C\)</span> in Euclidean distance. This requires <span
class="math inline">\(\log_2 m\)</span> bits to store the chosen index in
<span class="math inline">\(C\)</span>, and results in <span
class="math inline">\(k = \frac{\log_2 m}{n}\)</span> bits per
weight.</p>
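<p>Semantically, this rounding step is a nearest-neighbor search; a
brute-force sketch (real codebooks exploit structure instead of
materializing all <span class="math inline">\(m\)</span> entries):</p>
<pre><code class="language-python">import numpy as np

def quantize_to_codebook(x, C):
    """Round x (length n) to its nearest entry of codebook C (shape (m, n)).

    Only the index is stored: log2(m) bits for n weights, i.e.
    k = log2(m) / n bits per weight.
    """
    idx = int(np.argmin(np.sum((C - x) ** 2, axis=1)))
    return idx, C[idx]</code></pre>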
|
<p>Increasing <span class="math inline">\(n\)</span> results in a |
|
“rounder” codebook that reduces quantization error. However, note that |
|
the number of bits per weight is determined by <em>both</em> the number
of entries in <span class="math inline">\(C\)</span> (<span
class="math inline">\(m\)</span>) and the
dimension of <span class="math inline">\(C\)</span> (<span
class="math inline">\(n\)</span>). To maintain a
fixed number of bits per weight, a linear increase in <span
|
class="math inline">\(n\)</span> requires an exponential increase in |
|
<span class="math inline">\(m\)</span>. For example, a naively designed |
|
16-dimensional codebook requires <span |
|
class="math inline">\(2^{32}\)</span> entries to achieve 2 bits per |
|
weight, but performing lookups into a size <span |
|
class="math inline">\(2^{32}\)</span> codebook is intractable. Thus, it |
|
is important to design codebooks that have relatively large <span
class="math inline">\(n\)</span> while being compressible, so that the
actual lookup happens over fewer than <span
class="math inline">\(2^{nk}\)</span> entries.</p>
|
<p>Geometric lattices are suitable bases for such codebooks, as most
lattices have inherent symmetries and certain lattices achieve optimal
sphere-packing densities. For example, our E8P codebook based on the <span
|
class="math inline">\(E_8\)</span> lattice has <span |
|
class="math inline">\(2^{16}\)</span> entries but only requires looking |
|
up into a size <span class="math inline">\(2^8\)</span> codebook due to |
|
symmetries inherent to the <span class="math inline">\(E_8\)</span> |
|
lattice itself (more on this below). The <span
class="math inline">\(E_8\)</span> lattice achieves the 8-dimensional
kissing number, the maximum number of unit balls touching
a central unit ball in 8 dimensions. Interestingly, Maryna Viazovska
won the 2022 Fields Medal “for the proof that the <span
|
class="math inline">\(E_8\)</span> lattice provides the densest packing |
|
of identical spheres in 8 dimensions.”</p> |
|
<figure> |
|
<img src="img/kissing2d.png" |
|
alt="The 2D kissing number is 6, which is achieved by this packing configuration. Image from Wikipedia." /> |
|
<figcaption aria-hidden="true">The 2D kissing number is 6, which is |
|
achieved by this packing configuration. Image from |
|
Wikipedia.</figcaption> |
|
</figure> |
|
<h4 id="e8p-codebook">E8P Codebook</h4> |
|
<p>Our E8P codebook is a version of the <span |
|
class="math inline">\(E_8\)</span> lattice intersected with a ball, |
|
padded (hence the P in E8P) to reach <span |
|
class="math inline">\(2^{16}\)</span> entries. This results in <span |
|
class="math inline">\(k = 16/8 = 2\)</span> bits per weight. The <span |
|
class="math inline">\(E_8\)</span> lattice is composed of 8 dimensional |
|
all-integer or all-half integer vectors whose sum is an even number. In |
|
set-builder notation, <span class="math display">\[E_8 = \left\{x \mid x |
|
\in \left(\mathbb{Z}^8 \cup \left(\mathbb{Z}+\frac{1}{2}\right)^8\right) |
|
\land \sum_i x_i \equiv 0 \pmod 2\right\}.\]</span> Note that <span |
|
class="math inline">\(E_8 + \frac{1}{4}\)</span> has the same packing |
|
density as <span class="math inline">\(E_8\)</span> and is equivalent to
|
<span class="math inline">\(D_8 + \frac{1}{2} \pm \frac{1}{4}\)</span>, |
|
where <span class="math inline">\(D_8\)</span> is the set of
8-dimensional all-integer vectors with even sum. Denote <span
|
class="math inline">\(D_8 + \frac{1}{2}\)</span> as <span |
|
class="math inline">\(\hat{D_8}\)</span>; all elements in <span |
|
class="math inline">\(\hat{D_8}\)</span> also have even sum parity.</p> |
|
<p>Now, note that if we flip an even number of signs of an element in |
|
<span class="math inline">\(\hat{D_8}\)</span>, we get another element |
|
in <span class="math inline">\(\hat{D_8}\)</span>, whereas flipping an |
|
odd number of signs results in something not in <span |
|
class="math inline">\(\hat{D_8}\)</span>. This is due to <span |
|
class="math inline">\(\hat{D_8}\)</span> being a half integer grid; |
|
flipping a single half integer results in changing the sum parity. Since |
|
<span class="math inline">\(\hat{D_8}\)</span> has 8 dimensions, there |
|
are <span class="math inline">\(2^8/2 = 128\)</span> possible “even sign |
|
flip” patterns to stay within <span |
|
class="math inline">\(\hat{D_8}\)</span>. Conversely, there are also 128 |
|
“odd sign flip” patterns.</p> |
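<p>This parity property is easy to check numerically; a small
illustration with one <span class="math inline">\(\hat{D_8}\)</span>
element:</p>
<pre><code class="language-python">import numpy as np

v0 = np.full(8, 0.5)                     # in D8-hat: entries sum to 4 (even)
for flips in ([0, 3], [0, 3, 5]):        # an even, then an odd, number of flips
    v = v0.copy()
    v[flips] *= -1.0
    # each flip changes the sum by an odd integer, toggling its parity
    print(len(flips), "flips: sum =", v.sum(), "| in D8-hat:", v.sum() % 2 == 0)</code></pre>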
|
<p>If we start from some “source codebook” <span |
|
class="math inline">\(S\)</span> that is a subset of <span |
|
class="math inline">\(|\hat{D_8}|\)</span>, where <span |
|
class="math inline">\(|\cdot|\)</span> denotes the elementwise absolute |
|
value, we can apply the 128 odd or the 128 even sign-flip patterns to
each entry to generate a subset of
|
<span class="math inline">\(\hat{D_8}\)</span>. Each entry in <span |
|
class="math inline">\(S\)</span> is either an odd or even number of |
|
flips away from an entry in <span |
|
class="math inline">\(\hat{D_8}\)</span>, but not both. Thus, given an |
|
entry <span class="math inline">\(s \in S\)</span> and 7 out of the 8 |
|
sign flips, we can infer the last one from the parity of the 7 sign |
|
flips and <span class="math inline">\(s\)</span>. This lets us use the |
|
following bit pattern to store a 16-bit codeword in <span |
|
class="math inline">\(E_8 + \frac{1}{4}\)</span>: 8 bits for the entry |
|
index in <span class="math inline">\(S\)</span>, 7 bits for the sign |
|
flips of the right 7 dimensions of the entry in <span |
|
class="math inline">\(S\)</span>, and 1 bit to add or subtract <span |
|
class="math inline">\(\frac{1}{4}\)</span>.</p> |
|
<p>For example, if we had the codeword <code>0001010110010111</code>, |
|
the first 8 bits <code>00010101</code> = 21 would indicate that we start
with entry 21 of <span class="math inline">\(S\)</span>. In this
|
example, let this be the vector</p> |
|
<p><span class="math display">\[\left\{\frac{1}{2}, \frac{1}{2}, |
|
\frac{1}{2}, \frac{3}{2}, \frac{1}{2}, \frac{1}{2}, \frac{1}{2}, |
|
\frac{1}{2}\right\},\]</span></p> |
|
<p>which is not in <span class="math inline">\(\hat{D_8}\)</span>. Thus, |
|
<span class="math inline">\(s\)</span> requires an odd number of sign |
|
flips to get into <span class="math inline">\(\hat{D_8}\)</span>. Then, |
|
the next 7 bits <code>1001011</code> would indicate that we need to
negate the 1st, 2nd, 4th, and 7th entries from the right. Since we need an odd
number of sign flips, the 8th entry from the right is also flipped. The
|
sign-decoded vector is then</p> |
|
<p><span class="math display">\[\left\{-\frac{1}{2}, -\frac{1}{2}, |
|
\frac{1}{2}, \frac{3}{2}, -\frac{1}{2}, \frac{1}{2}, -\frac{1}{2}, |
|
-\frac{1}{2}\right\},\]</span></p> |
|
<p>which we can verify is in <span class="math inline">\(E_8\)</span>. |
|
Finally, the last bit <code>1</code> indicates that we need to add <span |
|
class="math inline">\(\frac{1}{4}\)</span>, so the final decoded vector |
|
is</p> |
|
<p><span class="math display">\[\left\{-\frac{1}{4}, -\frac{3}{4}, |
|
\frac{3}{4}, \frac{7}{4}, -\frac{1}{4}, \frac{3}{4}, -\frac{1}{4}, |
|
-\frac{1}{4}\right\},\]</span></p> |
|
<p>which is in <span class="math inline">\(E_8 + \frac{1}{4}\)</span> as |
|
desired.</p> |
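<p>The full decoding procedure is just a handful of bit operations. Below
is an illustrative Python sketch; the variable names and exact bit
packing are assumptions for exposition rather than a description of our
CUDA kernel. With entry 21 of <span class="math inline">\(S\)</span> set
to the vector above, it reproduces the worked example:</p>
<pre><code class="language-python">import numpy as np

def decode_e8p(codeword, S):
    """Decode a 16-bit E8P codeword (illustrative bit packing).

    S: (256, 8) array of source-codebook entries (entrywise absolute
    values of D8-hat points; the exact entries are listed in our code).
    Layout: top 8 bits index S, next 7 bits flip signs of the right 7
    dimensions, and the last bit adds or subtracts 1/4.
    """
    entry = S[(codeword >> 8) % 256].copy()
    sign_bits = (codeword >> 1) % 128
    shift = 0.25 if codeword % 2 else -0.25

    signs = np.ones(8)
    for j in range(7):                 # bit j flips the (j+1)-th entry from the right
        if (sign_bits >> j) % 2:
            signs[7 - j] = -1.0
    # Infer the 8th sign from parity: the sign-decoded point must land in
    # D8-hat (even sum), and every flip toggles the sum's parity.
    needs_odd = entry.sum() % 2 != 0   # entry itself has odd sum
    if (bin(sign_bits).count("1") % 2 == 0) == needs_odd:
        signs[0] = -1.0                # leftmost entry absorbs the parity fix
    return entry * signs + shift</code></pre>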
|
<p>Putting this all together, we can decode a size <span
|
class="math inline">\(2^{16}\)</span> codebook by looking up into only a |
|
size <span class="math inline">\(2^8\)</span> codebook (namely <span |
|
class="math inline">\(S\)</span>) and performing some operations. On |
|
hardware, this means that we can store the smaller <span |
|
class="math inline">\(2^8\)</span> codebook in local caches, avoiding |
|
performance-killing memory accesses that the larger <span
|
class="math inline">\(2^{16}\)</span> codebook would require. The |
|
question that remains is how to choose <span
|
class="math inline">\(S\)</span>. In our implementation, we set <span |
|
class="math inline">\(S\)</span> to be the 227 elements of <span |
|
class="math inline">\(|\hat{D_8}|\)</span> with norm <span |
|
class="math inline">\(\le \sqrt{10}\)</span> plus 29 elements from <span |
|
class="math inline">\(|\hat{D_8}|\)</span> that have norm <span |
|
class="math inline">\(\sqrt{12}\)</span>. The exact elements chosen can |
|
be found in our code.</p> |
|
<h4 id="codebook-errors">Codebook Errors</h4> |
|
<p>To show the optimality of our lattice codebooks, we plotted the |
|
minimum achievable elementwise MSE of quantizing an <span
|
class="math inline">\(n\)</span>-dimensional multivariate Gaussian to |
|
various <span class="math inline">\(k\)</span> bit codebooks. To create |
|
each codebook, we intersected a ball with the base lattice and increased |
|
the radius of the ball to get more bits. The half integer codebooks are |
|
formed from the <span class="math inline">\(n\)</span>-dimensional half |
|
integer grids.</p> |
|
<p>Specifically, each point in the graph below plots <span |
|
class="math inline">\((k, y)\)</span> where</p> |
|
<p><span class="math display">\[y = \min_{s \in \mathbb{R}^+} |
|
\frac{1}{n}\left\|\mbox{quantize}\left(\frac{\mathcal{N}(0_n, I_n)}{s}, |
|
\mbox{codebook}\right)*s - \mathcal{N}(0_n, I_n)\right\|^2\]</span></p> |
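<p>Each such point can be estimated by Monte Carlo with a grid search
over the scale <span class="math inline">\(s\)</span>; the sketch below
uses brute-force nearest-entry rounding, and the grid range is an
arbitrary illustrative choice:</p>
<pre><code class="language-python">import numpy as np

def codebook_mse(C, n_vectors=2000, n_scales=50, seed=0):
    """Monte Carlo estimate of the optimal-scale elementwise MSE for a
    codebook C of shape (m, n), with x drawn from N(0, I_n)."""
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n_vectors, C.shape[1]))
    best = np.inf
    for s in np.linspace(0.2, 2.0, n_scales):  # illustrative scale grid
        # brute-force nearest-entry rounding of each row of X / s
        idx = np.argmin(((X[:, None, :] / s - C[None]) ** 2).sum(-1), axis=1)
        best = min(best, float(np.mean((C[idx] * s - X) ** 2)))
    return best</code></pre>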
|
<figure> |
|
<img src="img/lattice_err.png" title="Lattice Errors" |
|
alt="Lowest element-wise mean squared error (MSE) achievable for quantizing a multivariate Gaussian to various codebooks. The E_8 lattice achieves the densest unit-sphere packing in 8 dimensions and our derivative codebooks have the lowest MSE." /> |
|
<figcaption aria-hidden="true">Lowest element-wise mean squared error |
|
(MSE) achievable for quantizing a multivariate Gaussian to various |
|
codebooks. The <span class="math inline">\(E_8\)</span> lattice achieves |
|
the <a href="https://en.wikipedia.org/wiki/Kissing_number">densest |
|
unit-sphere packing in 8 dimensions</a> and our derivative codebooks |
|
have the lowest MSE.</figcaption> |
|
</figure> |
|
<p>The <span class="math inline">\(E_8\)</span>-based codebooks achieve
lower MSEs than all other codebooks, including those based on the <span
class="math inline">\(D_4\)</span> lattice, which achieves the
4-dimensional kissing number. This figure also shows the importance of
|
having a large number of columns <span class="math inline">\(n\)</span>. |
|
Increasing the number of columns decreases the error for the half |
|
integer grid, as the resulting codebook is more “round.” Since the E8P |
|
codebook is actually the union of two shifted codebooks, each of which |
|
is a ball intersected with <span |
|
class="math inline">\(\hat{D_8}\)</span>, it is not perfectly round. |
|
This is reflected in the MSE plot, where it sits slightly above the |
|
<span class="math inline">\(E_8\)</span> line. However, there does not |
|
exist an <span class="math inline">\(E_8\)</span> codebook with exactly 2
|
bits, so E8P is still practically superior.</p> |
|
<h3 id="results">Results</h3> |
|
<p>Here, we present quantization results using QuIP# on Llama 1 and 2 |
|
models. All models were quantized using the Hadamard transform for |
|
incoherence processing and a weight scale factor of roughly 0.9 times |
|
the optimal scale for a multivariate Gaussian to compensate for |
|
inter-layer interactions. Furthermore, all Llama 2 models were evaluated |
|
using a context length of 4096 and all Llama 1 models were evaluated with
|
context length 2048; these numbers match the context length the models |
|
were trained with. These and other models can be found in our <a |
|
href="https://huggingface.co/relaxml">Hugging Face repository</a>.</p> |
|
<p>The table below contains results for all Llama 1 and 2 models when |
|
quantized to 2 bits using the E8P codebook. QuIP# achieves excellent |
|
performance across all model sizes on both language modeling and
zero-shot tasks. Furthermore, on the zero-shot tasks (ArcC, ArcE, BoolQ,
|
PiQA, WinoGrande), QuIP# models achieve near-native performance with |
|
minimal degradation. Additional results are available <a |
|
href="https://docs.google.com/spreadsheets/d/18woLrIBdVGUr9CuFDbK9pl_6QzEBl09hfnoe4Qkg7Hg/edit?usp=sharing">here</a>.</p> |
|
<div style="margin-left: -6%; |
|
margin-right: auto; |
|
width: 112%;"> |
|
<table> |
|
<caption>QuIP# results across all Llama 1 and 2 models. QuIP# achieves |
|
near-native performance at 2 bits on language modeling (C4, Wiki) and |
|
zero-shot (ArcC, ArcE, BoolQ, PiQA, WinoGrande) tasks.</caption>
|
<colgroup> |
|
<col style="width: 6%" /> |
|
<col style="width: 6%" /> |
|
<col style="width: 10%" /> |
|
<col style="width: 11%" /> |
|
<col style="width: 10%" /> |
|
<col style="width: 10%" /> |
|
<col style="width: 12%" /> |
|
<col style="width: 10%" /> |
|
<col style="width: 20%" /> |
|
</colgroup> |
|
<thead> |
|
<tr class="header"> |
|
<th style="text-align: center;">Model</th> |
|
<th style="text-align: center;">Method</th> |
|
<th style="text-align: center;">C4 <span |
|
class="math inline">\(\downarrow\)</span></th> |
|
<th style="text-align: center;">Wiki <span |
|
class="math inline">\(\downarrow\)</span></th> |
|
<th style="text-align: center;">ArcC <span |
|
class="math inline">\(\uparrow\)</span></th> |
|
<th style="text-align: center;">ArcE <span |
|
class="math inline">\(\uparrow\)</span></th> |
|
<th style="text-align: center;">BoolQ <span |
|
class="math inline">\(\uparrow\)</span></th> |
|
<th style="text-align: center;">PiQA <span |
|
class="math inline">\(\uparrow\)</span></th> |
|
<th style="text-align: center;">WinoGrande <span |
|
class="math inline">\(\uparrow\)</span></th> |
|
</tr> |
|
</thead> |
|
<tbody> |
|
<tr class="odd"> |
|
<td style="text-align: center;">2-70B</td> |
|
<td style="text-align: center;">fp16</td> |
|
<td style="text-align: center;">5.533</td> |
|
<td style="text-align: center;">3.120</td> |
|
<td style="text-align: center;">0.480</td> |
|
<td style="text-align: center;">0.597</td> |
|
<td style="text-align: center;">0.766</td> |
|
<td style="text-align: center;">0.809</td> |
|
<td style="text-align: center;">0.768</td> |
|
</tr> |
|
<tr class="even"> |
|
<td style="text-align: center;">2-70B</td> |
|
<td style="text-align: center;">QuIP#</td> |
|
<td style="text-align: center;">6.529</td> |
|
<td style="text-align: center;">4.158</td> |
|
<td style="text-align: center;">0.472</td> |
|
<td style="text-align: center;">0.595</td> |
|
<td style="text-align: center;">0.791</td> |
|
<td style="text-align: center;">0.786</td> |
|
<td style="text-align: center;">0.742</td> |
|
</tr> |
|
<tr class="odd"> |
|
<td style="text-align: center;">2-13B</td> |
|
<td style="text-align: center;">fp16</td> |
|
<td style="text-align: center;">6.520</td> |
|
<td style="text-align: center;">4.574</td> |
|
<td style="text-align: center;">0.443</td> |
|
<td style="text-align: center;">0.580</td> |
|
<td style="text-align: center;">0.690</td> |
|
<td style="text-align: center;">0.790</td> |
|
<td style="text-align: center;">0.699</td> |
|
</tr> |
|
<tr class="even"> |
|
<td style="text-align: center;">2-13B</td> |
|
<td style="text-align: center;">QuIP#</td> |
|
<td style="text-align: center;">8.755</td> |
|
<td style="text-align: center;">6.058</td> |
|
<td style="text-align: center;">0.371</td> |
|
<td style="text-align: center;">0.501</td> |
|
<td style="text-align: center;">0.665</td> |
|
<td style="text-align: center;">0.757</td> |
|
<td style="text-align: center;">0.636</td> |
|
</tr> |
|
<tr class="odd"> |
|
<td style="text-align: center;">2-7B</td> |
|
<td style="text-align: center;">fp16</td> |
|
<td style="text-align: center;">7.036</td> |
|
<td style="text-align: center;">5.116</td> |
|
<td style="text-align: center;">0.406</td> |
|
<td style="text-align: center;">0.535</td> |
|
<td style="text-align: center;">0.710</td> |
|
<td style="text-align: center;">0.769</td> |
|
<td style="text-align: center;">0.670</td> |
|
</tr> |
|
<tr class="even"> |
|
<td style="text-align: center;">2-7B</td> |
|
<td style="text-align: center;">QuIP#</td> |
|
<td style="text-align: center;">12.062</td> |
|
<td style="text-align: center;">8.224</td> |
|
<td style="text-align: center;">0.325</td> |
|
<td style="text-align: center;">0.428</td> |
|
<td style="text-align: center;">0.623</td> |
|
<td style="text-align: center;">0.712</td> |
|
<td style="text-align: center;">0.624</td> |
|
</tr> |
|
<tr class="odd"> |
|
<td style="text-align: center;">1-65b</td> |
|
<td style="text-align: center;">fp16</td> |
|
<td style="text-align: center;">5.811</td> |
|
<td style="text-align: center;">3.532</td> |
|
<td style="text-align: center;">0.463</td> |
|
<td style="text-align: center;">0.588</td> |
|
<td style="text-align: center;">0.823</td> |
|
<td style="text-align: center;">0.809</td> |
|
<td style="text-align: center;">0.771</td> |
|
</tr> |
|
<tr class="even"> |
|
<td style="text-align: center;">1-65b</td> |
|
<td style="text-align: center;">QuIP#</td> |
|
<td style="text-align: center;">6.744</td> |
|
<td style="text-align: center;">4.566</td> |
|
<td style="text-align: center;">0.436</td> |
|
<td style="text-align: center;">0.569</td> |
|
<td style="text-align: center;">0.817</td> |
|
<td style="text-align: center;">0.805</td> |
|
<td style="text-align: center;">0.736</td> |
|
</tr> |
|
<tr class="odd"> |
|
<td style="text-align: center;">1-30B</td> |
|
<td style="text-align: center;">fp16</td> |
|
<td style="text-align: center;">6.130</td> |
|
<td style="text-align: center;">4.101</td> |
|
<td style="text-align: center;">0.453</td> |
|
<td style="text-align: center;">0.590</td> |
|
<td style="text-align: center;">0.684</td> |
|
<td style="text-align: center;">0.801</td> |
|
<td style="text-align: center;">0.728</td> |
|
</tr> |
|
<tr class="even"> |
|
<td style="text-align: center;">1-30B</td> |
|
<td style="text-align: center;">QuIP#</td> |
|
<td style="text-align: center;">7.471</td> |
|
<td style="text-align: center;">5.317</td> |
|
<td style="text-align: center;">0.429</td> |
|
<td style="text-align: center;">0.545</td> |
|
<td style="text-align: center;">0.669</td> |
|
<td style="text-align: center;">0.779</td> |
|
<td style="text-align: center;">0.718</td> |
|
</tr> |
|
<tr class="odd"> |
|
<td style="text-align: center;">1-13B</td> |
|
<td style="text-align: center;">fp16</td> |
|
<td style="text-align: center;">6.798</td> |
|
<td style="text-align: center;">5.091</td> |
|
<td style="text-align: center;">0.444</td> |
|
<td style="text-align: center;">0.599</td> |
|
<td style="text-align: center;">0.684</td> |
|
<td style="text-align: center;">0.792</td> |
|
<td style="text-align: center;">0.701</td> |
|
</tr> |
|
<tr class="even"> |
|
<td style="text-align: center;">1-13B</td> |
|
<td style="text-align: center;">QuIP#</td> |
|
<td style="text-align: center;">8.425</td> |
|
<td style="text-align: center;">6.381</td> |
|
<td style="text-align: center;">0.387</td> |
|
<td style="text-align: center;">0.536</td> |
|
<td style="text-align: center;">0.647</td> |
|
<td style="text-align: center;">0.750</td> |
|
<td style="text-align: center;">0.669</td> |
|
</tr> |
|
<tr class="odd"> |
|
<td style="text-align: center;">1-7B</td> |
|
<td style="text-align: center;">fp16</td> |
|
<td style="text-align: center;">7.343</td> |
|
<td style="text-align: center;">5.677</td> |
|
<td style="text-align: center;">0.415</td> |
|
<td style="text-align: center;">0.525</td> |
|
<td style="text-align: center;">0.731</td> |
|
<td style="text-align: center;">0.774</td> |
|
<td style="text-align: center;">0.670</td> |
|
</tr> |
|
<tr class="even"> |
|
<td style="text-align: center;">1-7B</td> |
|
<td style="text-align: center;">QuIP#</td> |
|
<td style="text-align: center;">10.970</td> |
|
<td style="text-align: center;">8.286</td> |
|
<td style="text-align: center;">0.352</td> |
|
<td style="text-align: center;">0.464</td> |
|
<td style="text-align: center;">0.647</td> |
|
<td style="text-align: center;">0.720</td> |
|
<td style="text-align: center;">0.624</td> |
|
</tr> |
|
</tbody> |
|
</table> |
|
</div> |
|
</body> |
|
</html> |
|
|