<!DOCTYPE html>
<html xmlns="http://www.w3.org/1999/xhtml" lang="" xml:lang="">
<head>
<meta charset="utf-8" />
<meta name="generator" content="pandoc" />
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes" />
<title>QuIP#</title>
<style>
html {
font-family: cambria;
font-size: 16pt;
color: #1a1a1a;
background-color: #fdfdfd;
}
body {
margin: 0 auto;
max-width: 36em;
padding-left: 50px;
padding-right: 50px;
padding-top: 50px;
padding-bottom: 50px;
hyphens: auto;
overflow-wrap: break-word;
text-rendering: optimizeLegibility;
font-kerning: normal;
}
@media (max-width: 600px) {
body {
font-size: 0.9em;
padding: 12px;
}
h1 {
font-size: 1.8em;
}
}
@media print {
html {
background-color: white;
}
body {
background-color: transparent;
color: black;
font-size: 12pt;
}
p, h2, h3 {
orphans: 3;
widows: 3;
}
h2, h3, h4 {
page-break-after: avoid;
}
}
p {
margin: 1em 0;
}
a {
color: #1a1a1a;
}
a:visited {
color: #1a1a1a;
}
img {
max-width: 100%;
}
svg {
height: auto;
max-width: 100%;
}
h1, h2, h3, h4, h5, h6 {
margin-top: 1.4em;
}
h5, h6 {
font-size: 1em;
font-style: italic;
}
h6 {
font-weight: normal;
}
ol, ul {
padding-left: 1.7em;
margin-top: 1em;
}
li > ol, li > ul {
margin-top: 0;
}
blockquote {
margin: 1em 0 1em 1.7em;
padding-left: 1em;
border-left: 2px solid #e6e6e6;
color: #606060;
}
code {
font-family: Menlo, Monaco, Consolas, 'Lucida Console', monospace;
font-size: 85%;
margin: 0;
hyphens: manual;
}
pre {
margin: 1em 0;
overflow: auto;
}
pre code {
padding: 0;
overflow: visible;
overflow-wrap: normal;
}
.sourceCode {
background-color: transparent;
overflow: visible;
}
hr {
background-color: #1a1a1a;
border: none;
height: 1px;
margin: 1em 0;
}
table {
margin: 1em 0;
border-collapse: collapse;
width: 100%;
overflow-x: auto;
display: block;
font-variant-numeric: lining-nums tabular-nums;
}
table caption {
margin-bottom: 0.75em;
}
tbody {
margin-top: 0.5em;
border-top: 1px solid #1a1a1a;
border-bottom: 1px solid #1a1a1a;
}
th {
border-top: 1px solid #1a1a1a;
padding: 0.25em 0.5em 0.25em 0.5em;
}
td {
padding: 0.125em 0.5em 0.25em 0.5em;
}
header {
margin-bottom: 4em;
text-align: center;
}
#TOC li {
list-style: none;
}
#TOC ul {
padding-left: 1.3em;
}
#TOC > ul {
padding-left: 0;
}
#TOC a:not(:hover) {
text-decoration: none;
}
code{white-space: pre-wrap;}
span.smallcaps{font-variant: small-caps;}
div.columns{display: flex; gap: min(4vw, 1.5em);}
div.column{flex: auto; overflow-x: auto;}
div.hanging-indent{margin-left: 1.5em; text-indent: -1.5em;}
/* The extra [class] is a hack that increases specificity enough to
override a similar rule in reveal.js */
ul.task-list[class]{list-style: none;}
ul.task-list li input[type="checkbox"] {
font-size: inherit;
width: 0.8em;
margin: 0 0.8em 0.2em -1.6em;
vertical-align: middle;
}
</style>
<script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
<script
src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-chtml-full.js"
type="text/javascript"></script>
<!--[if lt IE 9]>
<script src="//cdnjs.cloudflare.com/ajax/libs/html5shiv/3.7.3/html5shiv-printshiv.min.js"></script>
<![endif]-->
</head>
<body>
<style>
body { max-width: 800px !important; text-align: justify; }
tbody {
border-top: none;
border-bottom: none;
}
header { height:0px;}
tr:nth-child(2n) {
background-color: #EEEEEE;
}
th {
background-color: #EEEEEE;
}
</style>
<h2 id="quip-quip-with-lattice-codebooks">QuIP#: <a
href="https://github.com/jerry-chee/QuIP">QuIP</a> with Lattice
Codebooks</h2>
<p><a href="https://tsengalb99.github.io">Albert Tseng*</a>, <a
href="https://jerry-chee.github.io/">Jerry Chee*</a>, <a
href="https://nalzok.github.io/">Qingyao Sun</a>, <a
href="https://www.cs.cornell.edu/~kuleshov/">Volodymyr Kuleshov</a>, and
<a href="https://www.cs.cornell.edu/~cdesa/">Chris De Sa</a></p>
<hr />
<p><img src="img/overview.svg" /></p>
<p>Large language models (LLMs) exhibit amazing performance on a wide
variety of tasks such as text modeling and code generation. However,
they are also very large. For example, Llama 2 70B has 70 billion
parameters that require 140GB of memory to store in half precision. This
presents many challenges, such as needing multiple GPUs just to serve a
single LLM. To address these issues, researchers have developed
compression methods that reduce the size of models without destroying
performance.</p>
<p>One class of methods, post-training quantization, compresses trained
model weights into lower precision formats to reduce memory
requirements. For example, quantizing a model from 16 bit to 2 bit
precision would reduce the size of the model by 8x, meaning that even
Llama 2 70B would fit on a single 24GB GPU. In this work, we introduce
<strong>QuIP#</strong>, which combines lattice codebooks with
incoherence processing to create state-of-the-art 2 bit quantized
models. These two methods allow QuIP# to significantly close the gap
between 2 bit quantized LLMs and unquantized 16 bit models.</p>
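<p>The memory figures above follow from simple arithmetic. The sketch below reproduces them under the assumption that only the weights are counted (no activations, caches, or per-layer metadata) and that 1 GB is 10<sup>9</sup> bytes; the helper name <code>model_size_gb</code> is ours, for illustration only.</p>

```python
def model_size_gb(num_params: float, bits_per_weight: float) -> float:
    """Storage for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_weight / 8 / 1e9

params = 70e9  # Llama 2 70B
fp16_gb = model_size_gb(params, 16)
two_bit_gb = model_size_gb(params, 2)
print(f"16 bit: {fp16_gb:.1f} GB, 2 bit: {two_bit_gb:.1f} GB, ratio: {fp16_gb / two_bit_gb:.0f}x")
```

<p>Under these assumptions the 2 bit weights come to 17.5 GB, which is below the 24GB of a single GPU.</p>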
<div style="margin-left: auto;
margin-right: auto;
width: 90%;">
<table style="width:100%;">
<caption>Quantization results on Llama 2 70B. QuIP# achieves near-native
performance at 2 bits, outperforming all other presented
baselines.</caption>
<colgroup>
<col style="width: 16%" />
<col style="width: 16%" />
<col style="width: 16%" />
<col style="width: 16%" />
<col style="width: 16%" />
<col style="width: 16%" />
</colgroup>
<thead>
<tr class="header">
<th style="text-align: center;">Method</th>
<th style="text-align: center;">Precision</th>
<th style="text-align: center;">Wiki <span
class="math inline">\(\downarrow\)</span></th>
<th style="text-align: center;">C4 <span
class="math inline">\(\downarrow\)</span></th>
<th style="text-align: center;">ArcE <span
class="math inline">\(\uparrow\)</span></th>
<th style="text-align: center;">PiQA <span
class="math inline">\(\uparrow\)</span></th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td style="text-align: center;">Native</td>
<td style="text-align: center;">16 bit</td>
<td style="text-align: center;">3.120</td>
<td style="text-align: center;">5.533</td>
<td style="text-align: center;">0.597</td>
<td style="text-align: center;">0.809</td>
</tr>
<tr class="even">
<td style="text-align: center;">OPTQ</td>
<td style="text-align: center;">3 bit</td>
<td style="text-align: center;">4.577</td>
<td style="text-align: center;">6.838</td>
<td style="text-align: center;">0.544</td>
<td style="text-align: center;"><strong>0.786</strong></td>
</tr>
<tr class="odd">
<td style="text-align: center;">OPTQ</td>
<td style="text-align: center;">2 bit</td>
<td style="text-align: center;">109.820</td>
<td style="text-align: center;">62.692</td>
<td style="text-align: center;">0.253</td>
<td style="text-align: center;">0.505</td>
</tr>
<tr class="even">
<td style="text-align: center;">QuIP</td>
<td style="text-align: center;">2 bit</td>
<td style="text-align: center;">5.574</td>
<td style="text-align: center;">8.268</td>
<td style="text-align: center;">0.544</td>
<td style="text-align: center;">0.751</td>
</tr>
<tr class="odd">
<td style="text-align: center;"><strong>QuIP#</strong></td>
<td style="text-align: center;"><strong>2 bit</strong></td>
<td style="text-align: center;"><strong>4.159</strong></td>
<td style="text-align: center;"><strong>6.529</strong></td>
<td style="text-align: center;"><strong>0.595</strong></td>
<td style="text-align: center;">0.786</td>
</tr>
</tbody>
</table>
</div>
<div
style="color:steelblue; margin-left: -14%; margin-right: auto; width: 115%">
<table>
<colgroup>
<col style="width: 3%" />
<col style="width: 96%" />
</colgroup>
<tbody>
<tr class="odd">
<td style="text-align: right;"><span
style="font-size:72pt"></span></td>
<td><strong>Our method, QuIP#, creates 2 bit LLMs that achieve
near-native performance, a previously unseen result. We provide a <a
href="https://huggingface.co/relaxml">full suite of 2 bit Llama 1 and 2
models quantized using QuIP#</a>, as well as a full codebase that allows
users to quantize and deploy their own models. We also provide CUDA
kernels that accelerate inference for QuIP# models. Our code is
available <a
href="https://github.com/Cornell-RelaxML/quip-sharp">here</a>.</strong></td>
</tr>
</tbody>
</table>
</div>
<h3 id="method-overview">Method Overview</h3>
<p>QuIP# relies on two main components: <em>incoherence processing</em>
and <em>lattice codebooks</em>. Incoherence processing in the context of
model quantization was introduced in QuIP. While QuIP used a Kronecker
product to perform incoherence processing, we introduce a Hadamard
transform-based incoherence approach that is more amenable to fast GPU
acceleration.</p>
<p>Incoherence-processed weights are approximately Gaussian-distributed,
which means that they are suitable for quantizing with symmetric and
“round” codebooks. We introduce a new lattice codebook based on the
<span class="math inline">\(E_8\)</span> lattice, which achieves the
optimal 8 dimension unit ball packing density. Our codebooks are
specifically designed to be hardware-friendly by exploiting symmetries
in these lattices.</p>
<h3 id="quantization-background">Quantization Background</h3>
<p>In QuIP#, we follow existing state-of-the-art post-training
quantization methods and round weights to minimize the per-layer
“adaptive rounding” proxy objective</p>
<p><span class="math display">\[
\ell(\hat W)
= E_x \left[ \| (\hat W - W)x \|^2 \right]
= \operatorname{tr}\left(
(\hat W - W) H (\hat W - W)^T
\right).
\]</span></p>
<p>Here, <span class="math inline">\(W \in \mathbb{R}^{m \times
n}\)</span> is the original weight matrix in a given layer, <span
class="math inline">\(\hat W \in \mathbb{R}^{m \times n}\)</span> are the
quantized weights, <span class="math inline">\(x \in
\mathbb{R}^n\)</span> is an input vector drawn uniformly at random from
a calibration set, and <span class="math inline">\(H\)</span> is the
second moment matrix of these vectors, interpreted as a proxy Hessian.
This intra-layer formulation makes quantization tractable for large
language models. The original QuIP paper formulated a class of adaptive
rounding methods that used linear feedback to minimize <span
class="math inline">\(\ell\)</span>. Within this class, the LDLQ
rounding algorithm was shown to be optimal; we use LDLQ in QuIP# as
well.</p>
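<p>As a sanity check on the identity above, the following minimal sketch evaluates both sides of the proxy objective on a small random example. Everything here is illustrative: the calibration inputs are synthetic Gaussians, nearest-integer rounding stands in for a real quantizer (it is not LDLQ), and the helper names are our own.</p>

```python
import random

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def trace(A):
    return sum(A[i][i] for i in range(len(A)))

def proxy_loss(W_hat, W, H):
    """tr((W_hat - W) H (W_hat - W)^T)."""
    D = [[wh - w for wh, w in zip(rh, r)] for rh, r in zip(W_hat, W)]
    return trace(matmul(matmul(D, H), transpose(D)))

random.seed(0)
m, n, num_samples = 3, 4, 500
W = [[random.gauss(0, 1) for _ in range(n)] for _ in range(m)]
W_hat = [[float(round(w)) for w in row] for row in W]  # stand-in quantizer, not LDLQ
X = [[random.gauss(0, 1) for _ in range(n)] for _ in range(num_samples)]  # synthetic calibration set

# H = E[x x^T], estimated as the empirical second moment of the calibration inputs.
H = [[sum(x[i] * x[j] for x in X) / num_samples for j in range(n)] for i in range(n)]

# Direct evaluation of E_x ||(W_hat - W) x||^2 over the same inputs.
D = [[wh - w for wh, w in zip(rh, r)] for rh, r in zip(W_hat, W)]
direct = sum(sum(sum(d * xi for d, xi in zip(row, x)) ** 2 for row in D) for x in X) / num_samples

via_trace = proxy_loss(W_hat, W, H)
print(via_trace, direct)
```

<p>The two printed values agree up to floating-point rounding, since the trace form is just the expectation rewritten in terms of the second moment matrix.</p>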
<h3 id="incoherence-processing">Incoherence Processing</h3>
<p>The main insight of QuIP is that incoherent weight and Hessian
matrices result in improved quantization performance. Informally,
weights that are even in magnitude, together with Hessians (the important
rounding directions) that are not too large in any one coordinate, are
significantly easier to quantize without catastrophic error. In some
sense, incoherence processing can be viewed as a form of outlier
suppression across weight and activation spaces.</p>
<div style="background-color: #EEEEEE;">
<p><strong>Definition.</strong> <em>We say a symmetric Hessian matrix
<span class="math inline">\(H \in \mathbb{R}^{n \times n}\)</span> is
<span class="math inline">\(\mu\)</span>-incoherent if it has an
eigendecomposition <span class="math inline">\(H = Q \Lambda
Q^T\)</span> such that for all <span class="math inline">\(i\)</span>
and <span class="math inline">\(j\)</span>, <span
class="math inline">\(|Q_{ij}| = |e_i^T Q e_j| \leq \mu /
\sqrt{n}\)</span>. By extension, we say a weight matrix <span
class="math inline">\(W \in \mathbb{R}^{m \times n}\)</span> is <span
class="math inline">\(\mu\)</span>-incoherent if for all <span
class="math inline">\(i\)</span> and <span
class="math inline">\(j\)</span>, <span class="math inline">\(|W_{ij}| =
|e_i^T W e_j| \leq \mu \|W\|_F / \sqrt{mn}\)</span>.</em></p>
</div>
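<p>To make the weight-matrix half of the definition concrete, the sketch below computes the smallest <span class="math inline">\(\mu\)</span> for which a given matrix is <span class="math inline">\(\mu\)</span>-incoherent. The function name and the two toy matrices are our own illustrative choices.</p>

```python
import math

def weight_incoherence(W):
    """Smallest mu such that max |W_ij| <= mu * ||W||_F / sqrt(m n)."""
    m, n = len(W), len(W[0])
    frob = math.sqrt(sum(w * w for row in W for w in row))
    max_abs = max(abs(w) for row in W for w in row)
    return max_abs * math.sqrt(m * n) / frob

flat = [[1.0, -1.0], [-1.0, 1.0]]   # every entry has the same magnitude
spiky = [[10.0, 0.1], [0.1, 0.1]]   # a single outlier dominates

mu_flat = weight_incoherence(flat)
mu_spiky = weight_incoherence(spiky)
print(mu_flat, mu_spiky)
```

<p>The flat matrix attains the smallest possible value <span class="math inline">\(\mu = 1\)</span>, while the matrix with one outlier approaches the worst case <span class="math inline">\(\mu = \sqrt{mn}\)</span>.</p>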
<p>Incoherence is an important property for quantizing models. In QuIP,
the incoherence condition on <span class="math inline">\(H\)</span> is
required to show that LDLQ achieves a superior proxy loss to nearest and
stochastic rounding through a spectral bound on <span
class="math inline">\(H\)</span>. Therefore, it is important to be able
to incoherence-process weight and Hessian matrices efficiently so that
incoherence-processed models can be tractably deployed.</p>
<p>One way to do this is by conjugating <span
class="math inline">\(W\)</span> and <span
class="math inline">\(H\)</span> by random orthogonal matrices. Let
<span class="math inline">\(U \in \mathbb{R}^{m \times m}\)</span>, and
<span class="math inline">\(V \in \mathbb{R}^{n \times n}\)</span> be
two random orthogonal matrices. If we assign <span
class="math inline">\(\tilde H \gets V H V^T\)</span> and <span
class="math inline">\(\tilde W \gets U W V^T\)</span>, <span
class="math inline">\(\tilde H\)</span> and <span
class="math inline">\(\tilde W\)</span> become incoherence processed
with high probability (see QuIP for proof). One can verify that this
transformation preserves the proxy objective as <span
class="math display">\[\operatorname{tr}(\tilde W \tilde H \tilde W^T) =
\operatorname{tr}((U W V^T) (V H V^T) (V W^T U^T)) =
\operatorname{tr}(WHW^T).\]</span></p>
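<p>The invariance above can be checked numerically. In this minimal sketch, two fixed 2D rotations stand in for the random orthogonal matrices <span class="math inline">\(U\)</span> and <span class="math inline">\(V\)</span>; the angles and the small matrices are arbitrary values chosen for illustration.</p>

```python
import math

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def trace(A):
    return sum(A[i][i] for i in range(len(A)))

def rotation(theta):
    """2 x 2 rotation matrix, which is orthogonal."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s], [s, c]]

W = [[1.0, 2.0], [3.0, -1.0]]
H = [[2.0, 0.5], [0.5, 1.0]]          # symmetric positive definite
U, V = rotation(0.7), rotation(-1.3)  # stand-ins for random orthogonal matrices

W_tilde = matmul(matmul(U, W), transpose(V))  # U W V^T
H_tilde = matmul(matmul(V, H), transpose(V))  # V H V^T

before = trace(matmul(matmul(W, H), transpose(W)))
after = trace(matmul(matmul(W_tilde, H_tilde), transpose(W_tilde)))
print(before, after)
```

<p>Both traces come out equal up to rounding, because <span class="math inline">\(V^T V = I\)</span> cancels in the middle and the trace is invariant under conjugation by <span class="math inline">\(U\)</span>.</p>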
<h4 id="randomized-hadamard-transformation-rht">Randomized Hadamard
Transformation (RHT)</h4>
<p>To construct <span class="math inline">\(U\)</span> and <span
class="math inline">\(V\)</span> from above, we use the RHT, which is
amenable to fast GPU implementation. In fact, one of the CUDA sample
kernels is the RHT. The RHT performs the multiplication <span
class="math inline">\(x \in \mathbb{R}^n \to \mathbb{H}Sx\)</span>,
where <span class="math inline">\(\mathbb{H}\)</span> is an <span
class="math inline">\(n \times n\)</span> Hadamard matrix (scaled by a
normalization factor) and <span class="math inline">\(S\)</span> is an
<span class="math inline">\(n\)</span>-dimensional random sign vector.
The RHT concentrates the entries of <span
class="math inline">\(x\)</span> and thus results in incoherent matrices
through an <a
href="http://www.cs.cmu.edu/afs/cs/user/dwoodruf/www/teaching/15859-fall20/lecture_2.1.pdf">application
of the Azuma-Hoeffding inequality</a>. Note that the Hadamard transform
can be computed more efficiently than a matrix multiplication via the
fast Walsh-Hadamard transform, which we employ directly for powers of 2.
To handle non power-of-two values of <span
class="math inline">\(n\)</span>, we perform the following
algorithm:</p>
<ol type="1">
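<li>Factor <span class="math inline">\(n\)</span> so that <span
class="math inline">\(n/p\)</span> is a power of 2, let <span
class="math inline">\(p\)</span> be the remaining dimension, and reshape
<span class="math inline">\(Sx\)</span> into an
<span class="math inline">\(n/p \times p\)</span> matrix.</li>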
<li>Perform the fast Walsh-Hadamard transform on <span
class="math inline">\(Sx\)</span> associated with dimension <span
class="math inline">\(n/p\)</span>.</li>
<li>Let <span class="math inline">\(\mathbb{H}&#39;\)</span> be a <span
class="math inline">\(p \times p\)</span> scaled Hadamard matrix. Apply
this Hadamard transform to <span class="math inline">\(Sx\)</span> on
the right, and reshape back.</li>
</ol>
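<p>The three steps above can be sketched in plain Python as follows. This is an illustrative reference version, not the CUDA kernel: the function names are ours, and the 12 &times; 12 Hadamard matrix is built with Paley's construction purely as an assumption so that the example has a concrete non-power-of-two factor (the text does not say which small Hadamard matrices are used).</p>

```python
import math
import random

def fwht(v):
    """Unnormalized fast Walsh-Hadamard transform; len(v) must be a power of 2."""
    v = list(v)
    h = 1
    while h < len(v):
        for start in range(0, len(v), 2 * h):
            for i in range(start, start + h):
                a, b = v[i], v[i + h]
                v[i], v[i + h] = a + b, a - b
        h *= 2
    return v

def paley_hadamard(q):
    """(q+1) x (q+1) Hadamard matrix via Paley's construction (prime q, q % 4 == 3)."""
    residues = {(a * a) % q for a in range(1, q)}
    def chi(a):
        a %= q
        return 0 if a == 0 else (1 if a in residues else -1)
    size = q + 1
    H = [[0] * size for _ in range(size)]
    for i in range(size):
        for j in range(size):
            if i == j or i == 0:
                H[i][j] = 1
            elif j == 0:
                H[i][j] = -1
            else:
                H[i][j] = chi(i - j)
    return H

def rht(x, signs, H_small=None):
    """y = H S x with a scaled Hadamard H; H_small is the p x p factor (None means p = 1)."""
    n = len(x)
    p = 1 if H_small is None else len(H_small)
    q = n // p
    assert q * p == n and q & (q - 1) == 0, "n / p must be a power of 2"
    sx = [s * xi for s, xi in zip(signs, x)]
    # Step 1: view Sx as an (n/p) x p row-major matrix.
    M = [sx[r * p:(r + 1) * p] for r in range(q)]
    # Step 2: fast Walsh-Hadamard transform along the dimension of size n/p.
    cols = [fwht([M[r][c] for r in range(q)]) for c in range(p)]
    M = [[cols[c][r] / math.sqrt(q) for c in range(p)] for r in range(q)]
    # Step 3: apply the scaled p x p Hadamard matrix on the right, then flatten.
    if H_small is not None:
        M = [[sum(row[k] * H_small[k][c] for k in range(p)) / math.sqrt(p)
              for c in range(p)] for row in M]
    return [val for row in M for val in row]

random.seed(0)
n = 48                      # 4 * 12, not a power of 2
x = [0.0] * n
x[5] = 1.0                  # a maximally "spiky" input
signs = [random.choice((-1, 1)) for _ in range(n)]
H12 = paley_hadamard(11)
y = rht(x, signs, H12)
print(max(abs(v) for v in y), sum(v * v for v in y))
```

<p>For this one-hot input, every output entry has magnitude <span class="math inline">\(1/\sqrt{48}\)</span> and the norm is preserved, which is the concentration effect described above in its most extreme form.</p>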
<p>The only consequence of performing RHT is needing to store two sign
vectors per layer: <span class="math inline">\(S_U\)</span> and <span
class="math inline">\(S_V\)</span>. Since large language models have
large weight and Hessian matrices, this only increases the number of
bits per weight in practice by less than 0.01, or a negligible
amount.</p>
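<p>A quick back-of-the-envelope check of this overhead, assuming each sign is stored as a single bit and using a hypothetical 4096 &times; 4096 layer as the example shape:</p>

```python
def sign_overhead_bits_per_weight(m: int, n: int) -> float:
    """Extra bits per weight for storing S_U (m signs) and S_V (n signs), one bit per sign."""
    return (m + n) / (m * n)

overhead = sign_overhead_bits_per_weight(4096, 4096)  # hypothetical layer shape
print(overhead)
```

<p>For this shape the overhead is roughly 0.0005 bits per weight, well under the 0.01 bound stated above.</p>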
<h3 id="lattice-codebooks">Lattice Codebooks</h3>
<p>Incoherence processed weights are approximately Gaussian-distributed,
meaning that they are symmetric and “round.” To take advantage of this
“roundness,” we can use <span class="math inline">\(n\)</span>
dimensional codebooks that quantize <span
class="math inline">\(n\)</span> weights at once. Specifically, to
quantize <span class="math inline">\(x \in \mathbb{R}^n\)</span> to a
<span class="math inline">\(n\)</span> dimensional codebook <span
class="math inline">\(C \in \mathbb{R}^{m \times n}\)</span>, we round
<span class="math inline">\(x\)</span> to its nearest distance-wise
entry in <span class="math inline">\(C\)</span>. This requires <span
class="math inline">\(\log_2m\)</span> bits to represent which index in
<span class="math inline">\(C\)</span> to store, and results in <span
class="math inline">\(k = \frac{\log_2m}{n}\)</span> bits per
weight.</p>
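<p>A minimal sketch of this rounding rule and the resulting bit rate is shown below. The 2-dimensional, 4-entry codebook is a toy example of our own, chosen only to keep the arithmetic visible.</p>

```python
import math

def quantize(x, codebook):
    """Return (index, entry) of the codebook row closest to x in Euclidean distance."""
    best = min(range(len(codebook)),
               key=lambda i: sum((a - b) ** 2 for a, b in zip(x, codebook[i])))
    return best, codebook[best]

def bits_per_weight(codebook):
    """k = log2(m) / n for a codebook with m entries of dimension n."""
    m, n = len(codebook), len(codebook[0])
    return math.log2(m) / n

codebook = [(-0.5, -0.5), (-0.5, 0.5), (0.5, -0.5), (0.5, 0.5)]  # toy codebook: m = 4, n = 2
idx, entry = quantize((0.3, -0.9), codebook)
k = bits_per_weight(codebook)
print(idx, entry, k)
```

<p>Only the index <code>idx</code> needs to be stored, so here two weights cost two bits in total, or one bit per weight.</p>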
<p>Increasing <span class="math inline">\(n\)</span> results in a
“rounder” codebook that reduces quantization error. However, note that
the number of bits per weight is determined by <em>both</em> the number
of entries in <span class="math inline">\(C\)</span> (m) as well as the
dimension of <span class="math inline">\(C\)</span> (n). To maintain a
set number of bits per weight, a linear increase in <span
class="math inline">\(n\)</span> requires an exponential increase in
<span class="math inline">\(m\)</span>. For example, a naively designed
16-dimensional codebook requires <span
class="math inline">\(2^{32}\)</span> entries to achieve 2 bits per
weight, but performing lookups into a size <span
class="math inline">\(2^{32}\)</span> codebook is intractable. Thus, it
is important to design codebooks that both have relatively large <span
class="math inline">\(n\)</span> while being compressible so the actual
lookup happens with less than <span
class="math inline">\(2^{nk}\)</span> entries.</p>
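<p>The exponential growth is easy to tabulate. The helper below (our own naming) simply inverts <span class="math inline">\(k = \frac{\log_2 m}{n}\)</span> to get the number of entries an unstructured codebook would need at a fixed bit rate.</p>

```python
def codebook_entries(n: int, k: int) -> int:
    """Entries m needed so that log2(m) / n == k bits per weight."""
    return 2 ** (n * k)

table = {n: codebook_entries(n, 2) for n in (1, 2, 4, 8, 16)}
for n, m in table.items():
    print(f"n = {n:2d}: m = {m}")
```

<p>At 2 bits per weight, going from 8 to 16 dimensions takes the table from 65,536 entries to over four billion.</p>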
<p>Geometric lattices are suitable bases for such codebooks as most
lattices have inherent symmetries and certain lattices achieve optimal
sphere packing densities. For example, our E8P codebook based on the <span
class="math inline">\(E_8\)</span> lattice has <span
class="math inline">\(2^{16}\)</span> entries but only requires looking
up into a size <span class="math inline">\(2^8\)</span> codebook due to
symmetries inherent to the <span class="math inline">\(E_8\)</span>
lattice itself – more on this later. In QuIP#, we present the E8P
codebook based on the 8-dimensional <span
class="math inline">\(E_8\)</span> lattice. This lattice achieves the 8
dimensional kissing number, or the maximum number of unit balls touching
a central unit ball in 8 dimensions. Interestingly, Maryna Viazovska
won the Fields Medal in 2022 “for the proof that the <span
class="math inline">\(E_8\)</span> lattice provides the densest packing
of identical spheres in 8 dimensions.”</p>
<figure>
<img src="img/kissing2d.png"
alt="The 2D kissing number is 6, which is achieved by this packing configuration. Image from Wikipedia." />
<figcaption aria-hidden="true">The 2D kissing number is 6, which is
achieved by this packing configuration. Image from
Wikipedia.</figcaption>
</figure>
<h4 id="e8p-codebook">E8P Codebook</h4>
<p>Our E8P codebook is a version of the <span
class="math inline">\(E_8\)</span> lattice intersected with a ball,
padded (hence the P in E8P) to reach <span
class="math inline">\(2^{16}\)</span> entries. This results in <span
class="math inline">\(k = 16/8 = 2\)</span> bits per weight. The <span
class="math inline">\(E_8\)</span> lattice is composed of 8 dimensional
all-integer or all-half integer vectors whose sum is an even number. In
set-builder notation, <span class="math display">\[E_8 = \left\{x \mid x
\in \left(\mathbb{Z}^8 \cup \left(\mathbb{Z}+\frac{1}{2}\right)^8\right)
\land \sum_i x_i \equiv 0 \pmod 2\right\}.\]</span> Note that <span
class="math inline">\(E_8 + \frac{1}{4}\)</span> has the same packing
density as <span class="math inline">\(E_8\)</span> and is equivalent to
<span class="math inline">\(D_8 + \frac{1}{2} \pm \frac{1}{4}\)</span>,
where <span class="math inline">\(D_8\)</span> is the set of 8
dimensional all-integer vectors with even sum. Denote <span
class="math inline">\(D_8 + \frac{1}{2}\)</span> as <span
class="math inline">\(\hat{D_8}\)</span>; all elements in <span
class="math inline">\(\hat{D_8}\)</span> also have even sum parity.</p>
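<p>The set-builder definitions translate directly into membership tests. The sketch below uses exact rational arithmetic; the function names <code>in_e8</code> and <code>in_d8_hat</code> are ours.</p>

```python
from fractions import Fraction

def in_e8(x):
    """All-integer or all-half-integer 8-vector whose coordinate sum is even."""
    x = [Fraction(v) for v in x]
    all_int = all(v.denominator == 1 for v in x)
    all_half = all(v.denominator == 2 for v in x)
    return len(x) == 8 and (all_int or all_half) and sum(x) % 2 == 0

def in_d8_hat(x):
    """Element of D_8 + 1/2: all-half-integer 8-vector whose coordinate sum is even."""
    x = [Fraction(v) for v in x]
    return len(x) == 8 and all(v.denominator == 2 for v in x) and sum(x) % 2 == 0

half = Fraction(1, 2)
print(in_e8([1, 1, 0, 0, 0, 0, 0, 0]))       # integer vector, sum 2
print(in_e8([1, 0, 0, 0, 0, 0, 0, 0]))       # integer vector, sum 1
print(in_d8_hat([half] * 8))                 # half-integer vector, sum 4
print(in_d8_hat([half] * 7 + [3 * half]))    # half-integer vector, sum 5
```

<p>Every element of <span class="math inline">\(\hat{D_8}\)</span> passes <code>in_e8</code> as well, since it is an all-half-integer vector with even sum.</p>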
<p>Now, note that if we flip an even number of signs of an element in
<span class="math inline">\(\hat{D_8}\)</span>, we get another element
in <span class="math inline">\(\hat{D_8}\)</span>, whereas flipping an
odd number of signs results in something not in <span
class="math inline">\(\hat{D_8}\)</span>. This is due to <span
class="math inline">\(\hat{D_8}\)</span> being a half integer grid;
flipping a single half integer results in changing the sum parity. Since
<span class="math inline">\(\hat{D_8}\)</span> has 8 dimensions, there
are <span class="math inline">\(2^8/2 = 128\)</span> possible “even sign
flip” patterns to stay within <span
class="math inline">\(\hat{D_8}\)</span>. Conversely, there are also 128
“odd sign flip” patterns.</p>
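<p>The parity argument can be verified by brute force over all <span class="math inline">\(2^8\)</span> sign patterns. In this sketch the starting vector is an arbitrary element of <span class="math inline">\(\hat{D_8}\)</span> picked for illustration, and <code>in_d8_hat</code> is our own helper.</p>

```python
from fractions import Fraction
from itertools import product

def in_d8_hat(x):
    """All-half-integer 8-vector whose coordinate sum is even."""
    return len(x) == 8 and all(v.denominator == 2 for v in x) and sum(x) % 2 == 0

half = Fraction(1, 2)
base = [half, half, half, 3 * half, half, half, half, 3 * half]  # sum is 6, so base is in D8_hat

even_patterns = odd_patterns = 0
even_stay = odd_stay = 0
for signs in product((1, -1), repeat=8):
    flipped = [s * v for s, v in zip(signs, base)]
    if signs.count(-1) % 2 == 0:
        even_patterns += 1
        even_stay += in_d8_hat(flipped)
    else:
        odd_patterns += 1
        odd_stay += in_d8_hat(flipped)

print(even_patterns, even_stay, odd_patterns, odd_stay)
```

<p>All 128 even patterns stay inside <span class="math inline">\(\hat{D_8}\)</span> and none of the 128 odd patterns do, because negating a half-integer entry changes the sum by an odd amount.</p>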
<p>If we start from some “source codebook” <span
class="math inline">\(S\)</span> that is a subset of <span
class="math inline">\(|\hat{D_8}|\)</span>, where <span
class="math inline">\(|\cdot|\)</span> denotes the elementwise absolute
value, we can use 128 odd or even sign flips to generate a subset of
<span class="math inline">\(\hat{D_8}\)</span>. Each entry in <span
class="math inline">\(S\)</span> is either an odd or even number of
flips away from an entry in <span
class="math inline">\(\hat{D_8}\)</span>, but not both. Thus, given an
entry <span class="math inline">\(s \in S\)</span> and 7 out of the 8
sign flips, we can infer the last one from the parity of the 7 sign
flips and <span class="math inline">\(s\)</span>. This lets us use the
following bit pattern to store a 16-bit codeword in <span
class="math inline">\(E_8 + \frac{1}{4}\)</span>: 8 bits for the entry
index in <span class="math inline">\(S\)</span>, 7 bits for the sign
flips of the right 7 dimensions of the entry in <span
class="math inline">\(S\)</span>, and 1 bit to add or subtract <span
class="math inline">\(\frac{1}{4}\)</span>.</p>
<p>For example, if we had the codeword <code>0001010110010111</code>,
the first 8 bits <code>00010101</code> = 21 would indicate that we start
with the 21st entry in <span class="math inline">\(S\)</span>. In this
example, let this be the vector</p>
<p><span class="math display">\[\left\{\frac{1}{2}, \frac{1}{2},
\frac{1}{2}, \frac{3}{2}, \frac{1}{2}, \frac{1}{2}, \frac{1}{2},
\frac{1}{2}\right\},\]</span></p>
<p>which is not in <span class="math inline">\(\hat{D_8}\)</span>. Thus,
<span class="math inline">\(s\)</span> requires an odd number of sign
flips to get into <span class="math inline">\(\hat{D_8}\)</span>. Then,
the next 7 bits <code>1001011</code> indicate that we need to negate
the 1st, 2nd, 4th, and 7th entries from the right. That is four sign
flips, and since we need an odd number, the 8th entry from the right is
also negated. The sign-decoded vector is then</p>
<p><span class="math display">\[\left\{-\frac{1}{2}, -\frac{1}{2},
\frac{1}{2}, \frac{3}{2}, -\frac{1}{2}, \frac{1}{2}, -\frac{1}{2},
-\frac{1}{2}\right\},\]</span></p>
<p>which we can verify is in <span class="math inline">\(E_8\)</span>.
Finally, the last bit <code>1</code> indicates that we need to add <span
class="math inline">\(\frac{1}{4}\)</span>, so the final decoded vector
is</p>
<p><span class="math display">\[\left\{-\frac{1}{4}, -\frac{1}{4},
\frac{3}{4}, \frac{7}{4}, -\frac{1}{4}, \frac{3}{4}, -\frac{1}{4},
-\frac{1}{4}\right\},\]</span></p>
<p>which is in <span class="math inline">\(E_8 + \frac{1}{4}\)</span> as
desired.</p>
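<p>The decoding procedure in this example can be written as a short reference routine. This is a sketch under stated assumptions rather than the actual kernel: the source table here is a hypothetical dictionary holding only the example's entry 21, the bit layout follows the description above, and we assume a final bit of <code>0</code> means subtracting <span class="math inline">\(\frac{1}{4}\)</span>.</p>

```python
from fractions import Fraction

def decode_e8p(codeword: int, source):
    """Decode a 16-bit codeword: 8 index bits, 7 sign bits, 1 shift bit."""
    idx = codeword >> 8
    sign_bits = (codeword >> 1) & 0x7F
    shift_bit = codeword & 1
    s = list(source[idx])
    # flips[i] says whether to negate the (i+1)-th entry counted from the right.
    flips = [(sign_bits >> i) & 1 for i in range(7)]
    # An odd coordinate sum means s is not in D8_hat and needs an odd number of flips.
    needs_odd = sum(s) % 2 == 1
    have_odd = sum(flips) % 2 == 1
    flips.append(1 if have_odd != needs_odd else 0)  # inferred flip for the leftmost entry
    for i, flip in enumerate(flips):
        if flip:
            s[7 - i] = -s[7 - i]
    # Assumption: shift bit 1 adds 1/4 and shift bit 0 subtracts 1/4.
    shift = Fraction(1, 4) if shift_bit else Fraction(-1, 4)
    return [v + shift for v in s]

half = Fraction(1, 2)
# Hypothetical source table containing only the entry used in the example.
source = {21: [half, half, half, 3 * half, half, half, half, half]}
decoded = decode_e8p(0b0001010110010111, source)
print([str(v) for v in decoded])
```

<p>Because the eighth sign is inferred from parity, the routine only ever reads the 256-entry table plus 8 bits of sign and shift information.</p>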
<p>Putting this all together, this lets us decode a size <span
class="math inline">\(2^{16}\)</span> codebook by looking up into only a
size <span class="math inline">\(2^8\)</span> codebook (namely <span
class="math inline">\(S\)</span>) and performing some operations. On
hardware, this means that we can store the smaller <span
class="math inline">\(2^8\)</span> codebook in local caches, avoiding
performance killing memory accesses that the larger <span
class="math inline">\(2^{16}\)</span> codebook would require. The
question remains then of how to choose <span
class="math inline">\(S\)</span>. In our implementation, we set <span
class="math inline">\(S\)</span> to be the 227 elements of <span
class="math inline">\(|\hat{D_8}|\)</span> with norm <span
class="math inline">\(\le \sqrt{10}\)</span> plus 29 elements from <span
class="math inline">\(|\hat{D_8}|\)</span> that have norm <span
class="math inline">\(\sqrt{12}\)</span>. The exact elements chosen can
be found in our code.</p>
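<p>The count of 227 can be reproduced by enumeration. The sketch below relies on one observation: every vector of positive half-integers is the elementwise absolute value of some element of <span class="math inline">\(\hat{D_8}\)</span>, since negating a single entry fixes the sum parity if needed. Coordinates are doubled so that all arithmetic stays in integers; the function name is ours.</p>

```python
from itertools import product

def count_by_doubled_sq_norm(max_coord=7):
    """Count positive half-integer 8-vectors, keyed by 4 * (squared norm).

    Each coordinate v is represented by u = 2v in {1, 3, 5, 7}, so the squared
    norm is sum(u_i^2) / 4. Coordinates above 7/2 cannot occur at norm <= sqrt(12).
    """
    counts = {}
    for u in product(range(1, max_coord + 1, 2), repeat=8):
        sq = sum(c * c for c in u)
        counts[sq] = counts.get(sq, 0) + 1
    return counts

counts = count_by_doubled_sq_norm()
within_sqrt10 = sum(c for sq, c in counts.items() if sq <= 40)  # norm^2 <= 10
at_sqrt12 = counts.get(48, 0)                                   # norm^2 == 12
print(within_sqrt10, at_sqrt12)
```

<p>The enumeration finds 227 vectors with norm at most <span class="math inline">\(\sqrt{10}\)</span> and 224 candidates with norm exactly <span class="math inline">\(\sqrt{12}\)</span>, from which 29 are selected to fill the table to 256 entries.</p>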
<h4 id="codebook-errors">Codebook Errors</h4>
<p>To show the optimality of our lattice codebooks, we plotted the
minimum achievable elementwise MSE of quantizing a <span
class="math inline">\(n\)</span>-dimensional multivariate Gaussian to
various <span class="math inline">\(k\)</span> bit codebooks. To create
each codebook, we intersected a ball with the base lattice and increased
the radius of the ball to get more bits. The half integer codebooks are
formed from the <span class="math inline">\(n\)</span>-dimensional half
integer grids.</p>
<p>Specifically, each point in the graph below plots <span
class="math inline">\((k, y)\)</span> where</p>
<p><span class="math display">\[y = \min_{s \in \mathbb{R}^+}
\frac{1}{n}\left\|\mbox{quantize}\left(\frac{\mathcal{N}(0_n, I_n)}{s},
\mbox{codebook}\right)*s - \mathcal{N}(0_n, I_n)\right\|^2\]</span></p>
<figure>
<img src="img/lattice_err.png" title="Lattice Errors"
alt="Lowest element-wise mean squared error (MSE) achievable for quantizing a multivariate Gaussian to various codebooks. The E_8 lattice achieves the densest unit-sphere packing in 8 dimensions and our derivative codebooks have the lowest MSE." />
<figcaption aria-hidden="true">Lowest element-wise mean squared error
(MSE) achievable for quantizing a multivariate Gaussian to various
codebooks. The <span class="math inline">\(E_8\)</span> lattice achieves
the <a href="https://en.wikipedia.org/wiki/Kissing_number">densest
unit-sphere packing in 8 dimensions</a> and our derivative codebooks
have the lowest MSE.</figcaption>
</figure>
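<p>A single point of such a curve can be estimated by Monte Carlo. The sketch below handles the simplest case, a 1-dimensional half integer grid at 2 bits, and searches a coarse grid of scales <span class="math inline">\(s\)</span>. The sample count, seed, and scale grid are arbitrary choices of ours, and this is not the script used to produce the figure.</p>

```python
import random

LEVELS = (-1.5, -0.5, 0.5, 1.5)  # 1D half integer grid with 4 entries, i.e. 2 bits

def quantize_half_integer(v):
    """Round v to the nearest codebook entry."""
    return min(LEVELS, key=lambda c: abs(v - c))

def mse_at_scale(samples, s):
    """Mean squared error of quantize(x / s) * s against x."""
    return sum((quantize_half_integer(x / s) * s - x) ** 2 for x in samples) / len(samples)

random.seed(0)
samples = [random.gauss(0.0, 1.0) for _ in range(20000)]
scales = [0.5 + 0.05 * i for i in range(31)]  # candidate scales from 0.5 to 2.0
best_mse, best_scale = min((mse_at_scale(samples, s), s) for s in scales)
print(best_mse, best_scale)
```

<p>Higher-dimensional codebooks would replace <code>quantize_half_integer</code> with a nearest-neighbor search over vectors and divide the squared error by <span class="math inline">\(n\)</span>, as in the formula above.</p>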
<p>The <span class="math inline">\(E_8\)</span>-based codebooks achieve
lower MSEs than all other codebooks, including those based on the <span
class="math inline">\(D_4\)</span> lattice that achieves the 4
dimensional kissing number. This figure also shows the importance of
having a large number of columns <span class="math inline">\(n\)</span>.
Increasing the number of columns decreases the error for the half
integer grid, as the resulting codebook is more “round.” Since the E8P
codebook is actually the union of two shifted codebooks, each of which
is a ball intersected with <span
class="math inline">\(\hat{D_8}\)</span>, it is not perfectly round.
This is reflected in the MSE plot, where it sits slightly above the
<span class="math inline">\(E_8\)</span> line. However, there does not
exist an <span class="math inline">\(E_8\)</span> codebook with exactly 2
bits per weight, so E8P is still practically superior.</p>
<h3 id="results">Results</h3>
<p>Here, we present quantization results using QuIP# on Llama 1 and 2
models. All models were quantized using the Hadamard transform for
incoherence processing and a weight scale factor of roughly 0.9 times
the optimal scale for a multivariate Gaussian to compensate for
inter-layer interactions. Furthermore, all Llama 2 models were evaluated
using a context length of 4096 and all Llama 1 models were evaluated with
context length 2048; these numbers match the context length the models
were trained with. These and other models can be found in our <a
href="https://huggingface.co/relaxml">Hugging Face repository</a>.</p>
<p>The table below contains results for all Llama 1 and 2 models when
quantized to 2 bits using the E8P codebook. QuIP# achieves excellent
performance across all model sizes on both language modeling and zero
shot tasks. Furthermore, on the zero-shot tasks (ArcC, ArcE, BoolQ,
PiQA, WinoGrande), QuIP# models achieve near-native performance with
minimal degradation. Additional results are available <a
href="https://docs.google.com/spreadsheets/d/18woLrIBdVGUr9CuFDbK9pl_6QzEBl09hfnoe4Qkg7Hg/edit?usp=sharing">here</a>.</p>
<div style="margin-left: -6%;
margin-right: auto;
width: 112%;">
<table>
<caption>QuIP# results across all Llama 1 and 2 models. QuIP# achieves
near-native performance at 2 bits on language modeling (C4, Wiki) and
zero shot (ArcC, ArcE, BoolQ, PiQA, WinoGrande) tasks.</caption>
<colgroup>
<col style="width: 6%" />
<col style="width: 6%" />
<col style="width: 10%" />
<col style="width: 11%" />
<col style="width: 10%" />
<col style="width: 10%" />
<col style="width: 12%" />
<col style="width: 10%" />
<col style="width: 20%" />
</colgroup>
<thead>
<tr class="header">
<th style="text-align: center;">Model</th>
<th style="text-align: center;">Method</th>
<th style="text-align: center;">C4 <span
class="math inline">\(\downarrow\)</span></th>
<th style="text-align: center;">Wiki <span
class="math inline">\(\downarrow\)</span></th>
<th style="text-align: center;">ArcC <span
class="math inline">\(\uparrow\)</span></th>
<th style="text-align: center;">ArcE <span
class="math inline">\(\uparrow\)</span></th>
<th style="text-align: center;">BoolQ <span
class="math inline">\(\uparrow\)</span></th>
<th style="text-align: center;">PiQA <span
class="math inline">\(\uparrow\)</span></th>
<th style="text-align: center;">WinoGrande <span
class="math inline">\(\uparrow\)</span></th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td style="text-align: center;">2-70B</td>
<td style="text-align: center;">fp16</td>
<td style="text-align: center;">5.533</td>
<td style="text-align: center;">3.120</td>
<td style="text-align: center;">0.480</td>
<td style="text-align: center;">0.597</td>
<td style="text-align: center;">0.766</td>
<td style="text-align: center;">0.809</td>
<td style="text-align: center;">0.768</td>
</tr>
<tr class="even">
<td style="text-align: center;">2-70B</td>
<td style="text-align: center;">QuIP#</td>
<td style="text-align: center;">6.529</td>
<td style="text-align: center;">4.158</td>
<td style="text-align: center;">0.472</td>
<td style="text-align: center;">0.595</td>
<td style="text-align: center;">0.791</td>
<td style="text-align: center;">0.786</td>
<td style="text-align: center;">0.742</td>
</tr>
<tr class="odd">
<td style="text-align: center;">2-13B</td>
<td style="text-align: center;">fp16</td>
<td style="text-align: center;">6.520</td>
<td style="text-align: center;">4.574</td>
<td style="text-align: center;">0.443</td>
<td style="text-align: center;">0.580</td>
<td style="text-align: center;">0.690</td>
<td style="text-align: center;">0.790</td>
<td style="text-align: center;">0.699</td>
</tr>
<tr class="even">
<td style="text-align: center;">2-13B</td>
<td style="text-align: center;">QuIP#</td>
<td style="text-align: center;">8.755</td>
<td style="text-align: center;">6.058</td>
<td style="text-align: center;">0.371</td>
<td style="text-align: center;">0.501</td>
<td style="text-align: center;">0.665</td>
<td style="text-align: center;">0.757</td>
<td style="text-align: center;">0.636</td>
</tr>
<tr class="odd">
<td style="text-align: center;">2-7B</td>
<td style="text-align: center;">fp16</td>
<td style="text-align: center;">7.036</td>
<td style="text-align: center;">5.116</td>
<td style="text-align: center;">0.406</td>
<td style="text-align: center;">0.535</td>
<td style="text-align: center;">0.710</td>
<td style="text-align: center;">0.769</td>
<td style="text-align: center;">0.670</td>
</tr>
<tr class="even">
<td style="text-align: center;">2-7B</td>
<td style="text-align: center;">QuIP#</td>
<td style="text-align: center;">12.062</td>
<td style="text-align: center;">8.224</td>
<td style="text-align: center;">0.325</td>
<td style="text-align: center;">0.428</td>
<td style="text-align: center;">0.623</td>
<td style="text-align: center;">0.712</td>
<td style="text-align: center;">0.624</td>
</tr>
<tr class="odd">
<td style="text-align: center;">1-65B</td>
<td style="text-align: center;">fp16</td>
<td style="text-align: center;">5.811</td>
<td style="text-align: center;">3.532</td>
<td style="text-align: center;">0.463</td>
<td style="text-align: center;">0.588</td>
<td style="text-align: center;">0.823</td>
<td style="text-align: center;">0.809</td>
<td style="text-align: center;">0.771</td>
</tr>
<tr class="even">
<td style="text-align: center;">1-65B</td>
<td style="text-align: center;">QuIP#</td>
<td style="text-align: center;">6.744</td>
<td style="text-align: center;">4.566</td>
<td style="text-align: center;">0.436</td>
<td style="text-align: center;">0.569</td>
<td style="text-align: center;">0.817</td>
<td style="text-align: center;">0.805</td>
<td style="text-align: center;">0.736</td>
</tr>
<tr class="odd">
<td style="text-align: center;">1-30B</td>
<td style="text-align: center;">fp16</td>
<td style="text-align: center;">6.130</td>
<td style="text-align: center;">4.101</td>
<td style="text-align: center;">0.453</td>
<td style="text-align: center;">0.590</td>
<td style="text-align: center;">0.684</td>
<td style="text-align: center;">0.801</td>
<td style="text-align: center;">0.728</td>
</tr>
<tr class="even">
<td style="text-align: center;">1-30B</td>
<td style="text-align: center;">QuIP#</td>
<td style="text-align: center;">7.471</td>
<td style="text-align: center;">5.317</td>
<td style="text-align: center;">0.429</td>
<td style="text-align: center;">0.545</td>
<td style="text-align: center;">0.669</td>
<td style="text-align: center;">0.779</td>
<td style="text-align: center;">0.718</td>
</tr>
<tr class="odd">
<td style="text-align: center;">1-13B</td>
<td style="text-align: center;">fp16</td>
<td style="text-align: center;">6.798</td>
<td style="text-align: center;">5.091</td>
<td style="text-align: center;">0.444</td>
<td style="text-align: center;">0.599</td>
<td style="text-align: center;">0.684</td>
<td style="text-align: center;">0.792</td>
<td style="text-align: center;">0.701</td>
</tr>
<tr class="even">
<td style="text-align: center;">1-13B</td>
<td style="text-align: center;">QuIP#</td>
<td style="text-align: center;">8.425</td>
<td style="text-align: center;">6.381</td>
<td style="text-align: center;">0.387</td>
<td style="text-align: center;">0.536</td>
<td style="text-align: center;">0.647</td>
<td style="text-align: center;">0.750</td>
<td style="text-align: center;">0.669</td>
</tr>
<tr class="odd">
<td style="text-align: center;">1-7B</td>
<td style="text-align: center;">fp16</td>
<td style="text-align: center;">7.343</td>
<td style="text-align: center;">5.677</td>
<td style="text-align: center;">0.415</td>
<td style="text-align: center;">0.525</td>
<td style="text-align: center;">0.731</td>
<td style="text-align: center;">0.774</td>
<td style="text-align: center;">0.670</td>
</tr>
<tr class="even">
<td style="text-align: center;">1-7B</td>
<td style="text-align: center;">QuIP#</td>
<td style="text-align: center;">10.970</td>
<td style="text-align: center;">8.286</td>
<td style="text-align: center;">0.352</td>
<td style="text-align: center;">0.464</td>
<td style="text-align: center;">0.647</td>
<td style="text-align: center;">0.720</td>
<td style="text-align: center;">0.624</td>
</tr>
</tbody>
</table>
</div>
</body>
</html>