lewtun HF staff commited on
Commit
93a9871
·
1 Parent(s): cb97215

Moar fixes

Browse files
Files changed (1) hide show
  1. app/src/index.html +6 -4
app/src/index.html CHANGED
@@ -47,7 +47,7 @@
47
  <d-contents>
48
  </d-contents>
49
 
50
- <p>Over the last few years, the scaling of <em><strong>train-time</strong></em> <strong>compute</strong><strong> </strong>has dominated the progress of large language models (LLMs). Although this paradigm has proven to be remarkably effective, the resources needed to pretrain ever larger models are becoming prohibitively expensive, with <a href="https://youtu.be/WXhikNA5PIc?feature=shared">billion-dollar clusters</a> already on the horizon. This trend has sparked significant interest in a complementary approach: <em><strong>test-time compute scaling</strong></em>. Rather than relying on ever-larger pretraining budgets, test-time methods use dynamic inference strategies that allow models to “think longer” on harder problems. A prominent example is <a href="https://openai.com/index/learning-to-reason-with-llms/">OpenAI’s o1 model</a>, which shows consistent improvement on difficult math problems as one increases the amount of test-time compute:</p>
51
 
52
  <figure id="1581384e-bcac-805f-8c2b-dff4509f45cb" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/compute.png.webp"><img style="width:672px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/compute.png.webp"/></a></figure>
53
 
@@ -137,9 +137,11 @@ def get_canonical_form(expression: str) -&gt; str:
137
  # Return the first occurring group in case of a tie
138
  return canonical_to_original[canonical_form]</code></pre>
139
 
140
- <p id="15d1384e-bcac-804e-a99c-fe5e83313a3d" class="">This approach was significantly faster than checking each pair of solutions independently for equality.</p></div></details><p id="15b1384e-bcac-80f7-83e8-e1d6b360faa4" class="">Here’s how majority voting performs when applied to the generations from Llama 3.2 1B Instruct:</p><figure id="15b1384e-bcac-8072-9987-d80031b97793" class="image"><a href="Scaling%20test-time%20compute%20with%20open%20models%201531384ebcac800b9d73fca3503eb783/methods-maj.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-maj.png"/></a></figure><p id="15b1384e-bcac-8020-8688-fe1713e92c2b" class="">The results show that majority voting yields a significant improvement over the greedy decoding baseline, but its gains start to plateau after approximately <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>64</mn></mrow><annotation encoding="application/x-tex">N=64</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">64</span></span></span></span></span><span></span></span> generations. This limitation arises because majority voting struggles with problems that require nuanced reasoning or tasks where errors are consistent across generations. If you’re also wondering why the majority voting accuracy is worse than the 0-shot CoT baseline for N=1 and 2, that’s because we sample at T=0.8, which makes it less likely we produce the correct answer among a handful of candidates.</p><p id="15b1384e-bcac-8075-8fef-f26f0b8e5559" class="">Building on the limitations of majority voting, let’s see how incorporating a reward model can enhance performance.</p>
 
 
141
 
142
- <h2 id="1591384e-bcac-8098-9db5-f76c9ce00e7a" class="">Beyond majority: Best-of-N</h2><p id="15b1384e-bcac-8019-9b5c-d11bae74628d" class="">Best-of-N is a simple, but effective extension to majority voting that uses a reward model to determine the most plausible answer. This method comes in two main variants:</p><ul id="15b1384e-bcac-80b4-aae4-d5e98e29debf" class="bulleted-list"><li style="list-style-type:disc"><strong>Vanilla Best-of-N:</strong> Generate <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> independent responses and select the one with the <em>highest RM reward</em> as the final answer. This ensures that the most confident individual response is chosen, but it doesn’t account for consistency across answers.</li></ul><ul id="15b1384e-bcac-8035-a394-fbd954af1984" class="bulleted-list"><li style="list-style-type:disc"><strong>Weighted Best-of-N:</strong> Aggregate scores across all identical responses and select the answer with the <em>highest total reward</em>. This approach prioritises high-quality answers by boosting their scores through repeated occurrences. Mathematically, the weighting across answers <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>a</mi><mi>i</mi></msub></mrow><annotation encoding="application/x-tex">a_i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5806em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span><span></span></span> is performed as follows:<figure id="15d1384e-bcac-80e5-8d68-fe7bad033482" class="equation"><style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><div class="equation-container"><span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><msub><mi>a</mi><mrow><mi mathvariant="normal">w</mi><mi mathvariant="normal">e</mi><mi mathvariant="normal">i</mi><mi mathvariant="normal">g</mi><mi mathvariant="normal">h</mi><mi mathvariant="normal">t</mi><mi mathvariant="normal">e</mi><mi mathvariant="normal">d</mi></mrow></msub><mo>=</mo><mi>arg</mi><mo>⁡</mo><munder><mrow><mi>max</mi><mo>⁡</mo></mrow><mi>a</mi></munder><munderover><mo>∑</mo><mrow><mi>i</mi><mo>=</mo><mn>1</mn></mrow><mi>N</mi></munderover><mi mathvariant="double-struck">I</mi><mo stretchy="false">(</mo><msub><mi>a</mi><mi>i</mi></msub><mo>=</mo><mi>a</mi><mo stretchy="false">)</mo><mo>⋅</mo><mrow><mi mathvariant="normal">R</mi><mi mathvariant="normal">M</mi></mrow><mo stretchy="false">(</mo><mi>p</mi><mo separator="true">,</mo><msub><mi>s</mi><mi>i</mi></msub><mo stretchy="false">)</mo><mtext> </mtext><mo separator="true">,</mo></mrow><annotation encoding="application/x-tex">a_\mathrm{weighted} = \arg\max_{a} \sum_{i=1}^{N} \mathbb{I}(a_i = a) \cdot \mathrm{RM}(p, s_i) \,,</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7167em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathnormal">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathrm mtight">weighted</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:3.106em;vertical-align:-1.2777em;"></span><span class="mop">ar<span style="margin-right:0.01389em;">g</span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.4306em;"><span style="top:-2.4em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">a</span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop">max</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.7em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.8283em;"><span style="top:-1.8723em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span><span class="mrel mtight">=</span><span class="mord mtight">1</span></span></span></span><span style="top:-3.05em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span><span style="top:-4.3em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.10903em;">N</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.2777em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathbb">I</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathrm">RM</span></span><span class="mopen">(</span><span class="mord mathnormal">p</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mpunct">,</span></span></span></span></span></div></figure><p id="15d1384e-bcac-8083-8f2a-d5701df84dcd" class="">where <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mrow><mi mathvariant="normal">R</mi><mi mathvariant="normal">M</mi></mrow><mo stretchy="false">(</mo><mi>p</mi><mo separator="true">,</mo><msub><mi>s</mi><mi>i</mi></msub><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\mathrm{RM}(p, s_i)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathrm">RM</span></span><span class="mopen">(</span><span class="mord mathnormal">p</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span><span></span></span> is the reward model score of the <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>i</mi></mrow><annotation encoding="application/x-tex">i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6595em;"></span><span class="mord mathnormal">i</span></span></span></span></span><span></span></span>-th solution solution <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>s</mi><mi>i</mi></msub></mrow><annotation encoding="application/x-tex">s_i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5806em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span><span></span></span> to problem <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi></mrow><annotation encoding="application/x-tex">p</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.1944em;"></span><span class="mord mathnormal">p</span></span></span></span></span><span></span></span>.</p></li></ul><p id="15d1384e-bcac-8012-8282-c0ed1215a611" class="">Typically, one usually uses an outcome reward model (ORM) to get a single, solution-level score. But to allow for fair comparison with the other search strategies discussed later, we will use the same PRM to score the solutions from Best-of-N. As illustrated below, PRMs produce a <em>cumulative</em> <em>sequence of step-level scores</em> per solution, so we need to perform a reduction over the steps to obtain a single solution-level score: </p><figure id="15d1384e-bcac-80d6-815f-c7d87fe313a6" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/prm-reductions.png"><img style="width:700px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/prm-reductions.png"/></a></figure><p id="15d1384e-bcac-80e7-8d1a-e0aab286f9f4" class="">In the literature, the most common reductions are the following:</p><ul id="15b1384e-bcac-80e4-92b4-e2bc90a9130a" class="bulleted-list"><li style="list-style-type:disc"><strong>Min: </strong>use the minimum score across all steps.</li></ul><ul id="15b1384e-bcac-8073-b4dc-fbfcfc0567bc" class="bulleted-list"><li style="list-style-type:disc"><strong>Prod: </strong>use the product of step-level scores.</li></ul><ul id="15b1384e-bcac-80ed-8cc5-fa6e2ce330fb" class="bulleted-list"><li style="list-style-type:disc"><strong>Last: </strong>use the final score in the steps. This score contains the cumulative information from all prior steps, so treats the PRM effectively as an ORM that is able to score partial solutions.</li></ul><p id="15b1384e-bcac-80ad-96d1-d313ae3e1954" class="">We experimented with each reduction and found—like DeepMind—that <em><strong>“last” performs best for our choice of task and PRM</strong></em>. We use this aggregation throughout all of our experiments and you can expand the detail below to see how we implemented it, along with the weighting procedure discussed above.</p>
143
 
144
  <p id="15d1384e-bcac-809a-8aa8-c52ca7301b52" class="">Here’s the results one gets from applying both variants of Best-of-N:</p><figure id="15b1384e-bcac-808d-857e-d492683a4a91" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-maj-bon.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-maj-bon.png"/></a></figure><p id="15b1384e-bcac-8001-9320-ff788bab0c52" class="">The results reveal a clear advantage: <strong>weighted Best-of-N</strong> consistently outperforms vanilla Best-of-N, especially with larger generation budgets. Its ability to aggregate scores across identical responses ensures that even less frequent but higher-quality answers are effectively prioritized.</p><p id="15b1384e-bcac-808a-b3ff-ee08c05a20af" class="">However, despite these improvements, we’re still falling short of the performance achieved by the Llama 8B model and the Best-of-N approach is starting to plateau at <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>256</mn></mrow><annotation encoding="application/x-tex">N=256</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">256</span></span></span></span></span><span></span></span> generations. Can we push the boundaries further by supervising the search process step-by-step? Let’s find out 🚀!</p>
145
 
@@ -148,7 +150,7 @@ def get_canonical_form(expression: str) -&gt; str:
148
 
149
  <h2 id="1591384e-bcac-80d2-8234-fe0e9a4df59d" class="">DVTS: boosting performance with diversity</h2><p id="1591384e-bcac-8044-b7c5-cf39e4aed683" class="">As we saw above beam search gives strong performance over Best-of-N, but tends to underperform on simpler problems and at large test-time compute budgets. To address this, we developed an extension we call Diverse Verifier Tree Search (DVTS) that is designed to maximise diversity at large <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span>.</p><p id="15a1384e-bcac-80ff-a97b-c7ccd88958e4" class="">DVTS works in a similar fashion as beam search, with the following modifications:</p><ol type="1" id="15d1384e-bcac-806c-8004-e054a98d98ef" class="numbered-list" start="1"><li>For a given <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> and <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi></mrow><annotation encoding="application/x-tex">M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span>, expand the initial set of beams into <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mi mathvariant="normal">/</mi><mi>M</mi></mrow><annotation encoding="application/x-tex">N/M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mord">/</span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span> <em>independent</em> subtrees.</li></ol><ol type="1" id="15d1384e-bcac-8081-8508-feb06a13469b" class="numbered-list" start="2"><li>For each subtree, select the step with the highest PRM score.</li></ol><ol type="1" id="15d1384e-bcac-806a-976f-ec9596cd9532" class="numbered-list" start="3"><li>Generate <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi></mrow><annotation encoding="application/x-tex">M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span> new steps from the nodes selected in step (2) and select the step with the highest PRM score.</li></ol><ol type="1" id="15d1384e-bcac-808e-aa2b-f391ec426953" class="numbered-list" start="4"><li>Repeat step (3) until the EOS token or maximum tree depth is reached.</li></ol><p id="15d1384e-bcac-8087-b916-d9603de035dd" class="">Here’s the results from applying DVTS to Llama 1B:</p><figure id="15b1384e-bcac-801c-a1e7-d4e544826da3" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-all.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-all.png"/></a></figure><p id="15b1384e-bcac-80e1-bc9b-dbdb5738b9f1" class="">As we can see, DVTS provides a complementary strategy to beam search: at small <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> beam search is more effective at finding correct solutions, but at large <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> the diversity of DVTS candidates kicks in and we get better performance. </p><p id="15d1384e-bcac-80a7-8379-dca3c329c433" class="">We can also see this manifested in the problem difficulty breakdown, where DVTS enhances performance on the easy / medium problems at large N, while beam search is best at small N across model problem difficulties:</p><figure id="15b1384e-bcac-807a-8dca-f322077cc616" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/levels-all.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/levels-all.png"/></a></figure>
150
 
151
- <h2 id="1591384e-bcac-806b-9dd0-c80a250c7754" class="">The best of all worlds: compute-optimal scaling</h2><p id="1591384e-bcac-80e0-93e6-ceaacc131142" class="">Armed with various search strategies, a natural question is which one is best? In the DeepMind paper, they proposed a <em><strong>compute-optimal</strong></em> <em><strong>scaling strategy</strong></em> where one selects the search method and hyperparameters <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>θ</mi></mrow><annotation encoding="application/x-tex">\theta</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span></span></span></span></span><span></span></span> that achieves the <em><strong>best performance for a given compute budget </strong></em><em><style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span></em><em><strong>:</strong></em></p><figure id="15e1384e-bcac-8054-8afa-ca441a776a05" class="equation"><style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><div class="equation-container"><span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><msubsup><mi>θ</mi><mrow><mi>q</mi><mo separator="true">,</mo><msup><mi>a</mi><mo>∗</mo></msup><mo stretchy="false">(</mo><mi>q</mi><mo stretchy="false">)</mo></mrow><mo>∗</mo></msubsup><mo stretchy="false">(</mo><mi>N</mi><mo stretchy="false">)</mo><mo>=</mo><mi><munder><mo><mi>arg</mi><mo>⁡</mo><mi>max</mi><mo>⁡</mo></mo><mi>θ</mi></munder></mi><mrow><mo fence="true">(</mo><msub><mi mathvariant="double-struck">E</mi><mrow><mi>y</mi><mo>∼</mo><mtext>Target</mtext><mo stretchy="false">(</mo><mi>θ</mi><mo separator="true">,</mo><mi>N</mi><mo separator="true">,</mo><mi>q</mi><mo stretchy="false">)</mo></mrow></msub><mrow><mo fence="true">[</mo><msub><mn mathvariant="double-struck">1</mn><mrow><mi>y</mi><mo>=</mo><msup><mi>y</mi><mo>∗</mo></msup><mo stretchy="false">(</mo><mi>q</mi><mo stretchy="false">)</mo></mrow></msub><mo fence="true">]</mo></mrow><mo fence="true">)</mo></mrow><mo separator="true">,</mo></mrow><annotation encoding="application/x-tex">\theta_{q,a^*(q)}^*(N) = \underset{\theta}{\arg\max} \left( \mathbb{E}_{y \sim \text{Target}(\theta, N, q)} \left[ \mathbb{1}_{y = y^*(q)} \right] \right),</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.197em;vertical-align:-0.447em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7387em;"><span style="top:-2.428em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mpunct mtight">,</span><span class="mord mtight"><span class="mord mathnormal mtight">a</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.6183em;"><span style="top:-2.786em;margin-right:0.0714em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mopen mtight">(</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mclose mtight">)</span></span></span></span><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.447em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.7965em;vertical-align:-0.9465em;"></span><span class="mord"><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.4306em;"><span style="top:-2.1535em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop"><span class="mop">ar<span style="margin-right:0.01389em;">g</span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop">max</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.9465em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="minner"><span class="mopen delimcenter" style="top:0em;"><span class="delimsizing size1">(</span></span><span class="mord"><span class="mord mathbb">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.5198em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="mrel mtight">∼</span><span class="mord text mtight"><span class="mord mtight">Target</span></span><span class="mopen mtight">(</span><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight" style="margin-right:0.10903em;">N</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.3552em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="minner"><span class="mopen delimcenter" style="top:0em;"><span class="delimsizing size1">[</span></span><span class="mord"><span class="mord">1</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.5198em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="mrel mtight">=</span><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.6183em;"><span style="top:-2.786em;margin-right:0.0714em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mopen mtight">(</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.3552em;"><span></span></span></span></span></span></span><span class="mclose delimcenter" style="top:0em;"><span class="delimsizing size1">]</span></span></span><span class="mclose delimcenter" style="top:0em;"><span class="delimsizing size1">)</span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mpunct">,</span></span></span></span></span></div></figure><p id="15e1384e-bcac-8011-bf10-cc868e25db7c" class="">where <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mi>y</mi><mo>∗</mo></msup><mo stretchy="false">(</mo><mi>q</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">y^*(q)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.6887em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="mclose">)</span></span></span></span></span><span></span></span> is the ground-truth for question <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>q</mi></mrow><annotation encoding="application/x-tex">q</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.1944em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">q</span></span></span></span></span><span></span></span> and <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msubsup><mi>θ</mi><mrow><mi>q</mi><mo separator="true">,</mo><msup><mi>a</mi><mo>∗</mo></msup><mo stretchy="false">(</mo><mi>q</mi><mo stretchy="false">)</mo></mrow><mo>∗</mo></msubsup><mo stretchy="false">(</mo><mi>N</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\theta_{q,a^*(q)}^*(N)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.247em;vertical-align:-0.497em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6887em;"><span style="top:-2.378em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mpunct mtight">,</span><span class="mord mtight"><span class="mord mathnormal mtight">a</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.6183em;"><span style="top:-2.786em;margin-right:0.0714em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mopen mtight">(</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mclose mtight">)</span></span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.497em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mclose">)</span></span></span></span></span><span></span></span> denotes the compute-optimal scaling strategy. Since computing <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msubsup><mi>θ</mi><mrow><mi>q</mi><mo separator="true">,</mo><msup><mi>a</mi><mo>∗</mo></msup><mo stretchy="false">(</mo><mi>q</mi><mo stretchy="false">)</mo></mrow><mo>∗</mo></msubsup><mo stretchy="false">(</mo><mi>N</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\theta_{q,a^*(q)}^*(N)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.247em;vertical-align:-0.497em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6887em;"><span style="top:-2.378em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mpunct mtight">,</span><span class="mord mtight"><span class="mord mathnormal mtight">a</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.6183em;"><span style="top:-2.786em;margin-right:0.0714em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mopen mtight">(</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mclose mtight">)</span></span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.497em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mclose">)</span></span></span></span></span><span></span></span> directly is somewhat tricky, DeepMind proposed an approximation based on the <em><strong>problem difficulty</strong></em>, i.e. allocate test-time compute according to which search strategy achieves best performance for a given difficulty level.</p><p id="15a1384e-bcac-80c9-a276-d5ea8974c543" class="">For example, on simpler problems and lower compute budgets, it is better to use strategies like Best-of-N, while on harder problems, beam search is the better choice. We can represent the compute-optimal strategy mathematically as follows: </p><p id="15e1384e-bcac-8060-817e-daf4b3b7e34d" class="">[ADD DETAILS]</p><p id="15a1384e-bcac-806f-92f0-f3ac4ddf7b19" class="">And voila, we now have our compute-optimal curve!</p><figure id="15b1384e-bcac-80b3-bc58-d20ba41d3950" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-opt.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-opt.png"/></a></figure>
152
 
153
  <h2 id="1591384e-bcac-809a-96d2-e928398d159a" class="">Scaling up to larger models</h2><p id="15a1384e-bcac-8078-86d7-f48c2146444e" class="">We also explored scaling up the compute-optimal recipe to Llama 3.2 3B Instruct to see at what point the benefits of the PRM fade in comparison to the policy’s own capacity. To our surprise, compute-optimal scaling works remarkably well, with the 3B model surpassing the performance of Llama 3.1 70B Instruct (22x it's size!):</p><figure id="15b1384e-bcac-80b3-bc58-d20ba41d3950" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-opt-3b.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-opt-3b.png"/></a></figure>
154
 
 
47
  <d-contents>
48
  </d-contents>
49
 
50
+ <p>Over the last few years, the scaling of <em><strong>train-time</strong></em> <strong>compute</strong><strong> </strong>has dominated the progress of large language models (LLMs). Although this paradigm has proven to be remarkably effective, the resources needed to pretrain ever larger models are becoming prohibitively expensive, with <a href="https://youtu.be/WXhikNA5PIc?feature=shared">billion-dollar clusters</a> already on the horizon. This trend has sparked significant interest in a complementary approach: <em><strong>test-time compute scaling</strong></em>. Rather than relying on ever-larger pretraining budgets, test-time methods use dynamic inference strategies that allow models to “think longer” on harder problems. A prominent example is <a href="https://openai.com/index/learning-to-reason-with-llms/">OpenAI’s o1 model</a>, which shows consistent improvement on difficult math problems as one increases the amount of test-time compute:</p>
51
 
52
  <figure id="1581384e-bcac-805f-8c2b-dff4509f45cb" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/compute.png.webp"><img style="width:672px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/compute.png.webp"/></a></figure>
53
 
 
137
  # Return the first occurring group in case of a tie
138
  return canonical_to_original[canonical_form]</code></pre>
139
 
140
+ <p id="15d1384e-bcac-804e-a99c-fe5e83313a3d" class="">This approach was significantly faster than checking each pair of solutions independently for equality.</p></div></details>
141
+ <br>
142
+ <p id="15b1384e-bcac-80f7-83e8-e1d6b360faa4" class="">Here’s how majority voting performs when applied to the generations from Llama 3.2 1B Instruct:</p><figure id="15b1384e-bcac-8072-9987-d80031b97793" class="image"><a href="Scaling%20test-time%20compute%20with%20open%20models%201531384ebcac800b9d73fca3503eb783/methods-maj.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-maj.png"/></a></figure><p id="15b1384e-bcac-8020-8688-fe1713e92c2b" class="">The results show that majority voting yields a significant improvement over the greedy decoding baseline, but its gains start to plateau after approximately <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>64</mn></mrow><annotation encoding="application/x-tex">N=64</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">64</span></span></span></span></span><span></span></span> generations. This limitation arises because majority voting struggles with problems that require nuanced reasoning or tasks where errors are consistent across generations. If you’re also wondering why the majority voting accuracy is worse than the 0-shot CoT baseline for N=1 and 2, that’s because we sample at T=0.8, which makes it less likely we produce the correct answer among a handful of candidates.</p><p id="15b1384e-bcac-8075-8fef-f26f0b8e5559" class="">Building on the limitations of majority voting, let’s see how incorporating a reward model can enhance performance.</p>
143
 
144
+ <h2 id="1591384e-bcac-8098-9db5-f76c9ce00e7a" class="">Beyond majority: Best-of-N</h2><p id="15b1384e-bcac-8019-9b5c-d11bae74628d" class="">Best-of-N is a simple, but effective extension to majority voting that uses a reward model to determine the most plausible answer. This method comes in two main variants:</p><ul id="15b1384e-bcac-80b4-aae4-d5e98e29debf" class="bulleted-list"><li style="list-style-type:disc"><strong>Vanilla Best-of-N:</strong> Generate <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> independent responses and select the one with the <em>highest RM reward</em> as the final answer. This ensures that the most confident individual response is chosen, but it doesn’t account for consistency across answers.</li></ul><ul id="15b1384e-bcac-8035-a394-fbd954af1984" class="bulleted-list"><li style="list-style-type:disc"><strong>Weighted Best-of-N:</strong> Aggregate scores across all identical responses and select the answer with the <em>highest total reward</em>. This approach prioritises high-quality answers by boosting their scores through repeated occurrences. Mathematically, the weighting across answers <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>a</mi><mi>i</mi></msub></mrow><annotation encoding="application/x-tex">a_i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5806em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span><span></span></span> is performed as follows:<figure id="15d1384e-bcac-80e5-8d68-fe7bad033482" class="equation"><style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><div class="equation-container"><span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><msub><mi>a</mi><mrow><mi mathvariant="normal">w</mi><mi mathvariant="normal">e</mi><mi mathvariant="normal">i</mi><mi mathvariant="normal">g</mi><mi mathvariant="normal">h</mi><mi mathvariant="normal">t</mi><mi mathvariant="normal">e</mi><mi mathvariant="normal">d</mi></mrow></msub><mo>=</mo><mi>arg</mi><mo>⁡</mo><munder><mrow><mi>max</mi><mo>⁡</mo></mrow><mi>a</mi></munder><munderover><mo>∑</mo><mrow><mi>i</mi><mo>=</mo><mn>1</mn></mrow><mi>N</mi></munderover><mi mathvariant="double-struck">I</mi><mo stretchy="false">(</mo><msub><mi>a</mi><mi>i</mi></msub><mo>=</mo><mi>a</mi><mo stretchy="false">)</mo><mo>⋅</mo><mrow><mi mathvariant="normal">R</mi><mi mathvariant="normal">M</mi></mrow><mo stretchy="false">(</mo><mi>p</mi><mo separator="true">,</mo><msub><mi>s</mi><mi>i</mi></msub><mo stretchy="false">)</mo><mtext> </mtext><mo separator="true">,</mo></mrow><annotation encoding="application/x-tex">a_\mathrm{weighted} = \arg\max_{a} \sum_{i=1}^{N} \mathbb{I}(a_i = a) \cdot \mathrm{RM}(p, s_i) \,,</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7167em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathnormal">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathrm mtight">weighted</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:3.106em;vertical-align:-1.2777em;"></span><span class="mop">ar<span style="margin-right:0.01389em;">g</span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.4306em;"><span style="top:-2.4em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">a</span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop">max</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.7em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.8283em;"><span style="top:-1.8723em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span><span class="mrel mtight">=</span><span class="mord mtight">1</span></span></span></span><span style="top:-3.05em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span><span style="top:-4.3em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.10903em;">N</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.2777em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathbb">I</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathrm">RM</span></span><span class="mopen">(</span><span class="mord mathnormal">p</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mpunct">,</span></span></span></span></span></div></figure><p id="15d1384e-bcac-8083-8f2a-d5701df84dcd" class="">where <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mrow><mi mathvariant="normal">R</mi><mi mathvariant="normal">M</mi></mrow><mo stretchy="false">(</mo><mi>p</mi><mo separator="true">,</mo><msub><mi>s</mi><mi>i</mi></msub><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\mathrm{RM}(p, s_i)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathrm">RM</span></span><span class="mopen">(</span><span class="mord mathnormal">p</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span><span></span></span> is the reward model score of the <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>i</mi></mrow><annotation encoding="application/x-tex">i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6595em;"></span><span class="mord mathnormal">i</span></span></span></span></span><span></span></span>-th solution solution <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>s</mi><mi>i</mi></msub></mrow><annotation encoding="application/x-tex">s_i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5806em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span><span></span></span> to problem p.</p></li></ul><p id="15d1384e-bcac-8012-8282-c0ed1215a611" class="">Typically, one usually uses an outcome reward model (ORM) to get a single, solution-level score. But to allow for fair comparison with the other search strategies discussed later, we will use the same PRM to score the solutions from Best-of-N. As illustrated below, PRMs produce a <em>cumulative</em> <em>sequence of step-level scores</em> per solution, so we need to perform a reduction over the steps to obtain a single solution-level score: </p><figure id="15d1384e-bcac-80d6-815f-c7d87fe313a6" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/prm-reductions.png"><img style="width:700px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/prm-reductions.png"/></a></figure><p id="15d1384e-bcac-80e7-8d1a-e0aab286f9f4" class="">In the literature, the most common reductions are the following:</p><ul id="15b1384e-bcac-80e4-92b4-e2bc90a9130a" class="bulleted-list"><li style="list-style-type:disc"><strong>Min: </strong>use the minimum score across all steps.</li></ul><ul id="15b1384e-bcac-8073-b4dc-fbfcfc0567bc" class="bulleted-list"><li style="list-style-type:disc"><strong>Prod: </strong>use the product of step-level scores.</li></ul><ul id="15b1384e-bcac-80ed-8cc5-fa6e2ce330fb" class="bulleted-list"><li style="list-style-type:disc"><strong>Last: </strong>use the final score in the steps. This score contains the cumulative information from all prior steps, so treats the PRM effectively as an ORM that is able to score partial solutions.</li></ul><p id="15b1384e-bcac-80ad-96d1-d313ae3e1954" class="">We experimented with each reduction and found—like DeepMind—that <em><strong>“last” performs best for our choice of task and PRM</strong></em>. We use this aggregation throughout all of our experiments and you can expand the detail below to see how we implemented it, along with the weighting procedure discussed above.</p>
145
 
146
  <p id="15d1384e-bcac-809a-8aa8-c52ca7301b52" class="">Here’s the results one gets from applying both variants of Best-of-N:</p><figure id="15b1384e-bcac-808d-857e-d492683a4a91" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-maj-bon.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-maj-bon.png"/></a></figure><p id="15b1384e-bcac-8001-9320-ff788bab0c52" class="">The results reveal a clear advantage: <strong>weighted Best-of-N</strong> consistently outperforms vanilla Best-of-N, especially with larger generation budgets. Its ability to aggregate scores across identical responses ensures that even less frequent but higher-quality answers are effectively prioritized.</p><p id="15b1384e-bcac-808a-b3ff-ee08c05a20af" class="">However, despite these improvements, we’re still falling short of the performance achieved by the Llama 8B model and the Best-of-N approach is starting to plateau at <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>256</mn></mrow><annotation encoding="application/x-tex">N=256</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">256</span></span></span></span></span><span></span></span> generations. Can we push the boundaries further by supervising the search process step-by-step? Let’s find out 🚀!</p>
147
 
 
150
 
151
  <h2 id="1591384e-bcac-80d2-8234-fe0e9a4df59d" class="">DVTS: boosting performance with diversity</h2><p id="1591384e-bcac-8044-b7c5-cf39e4aed683" class="">As we saw above beam search gives strong performance over Best-of-N, but tends to underperform on simpler problems and at large test-time compute budgets. To address this, we developed an extension we call Diverse Verifier Tree Search (DVTS) that is designed to maximise diversity at large <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span>.</p><p id="15a1384e-bcac-80ff-a97b-c7ccd88958e4" class="">DVTS works in a similar fashion as beam search, with the following modifications:</p><ol type="1" id="15d1384e-bcac-806c-8004-e054a98d98ef" class="numbered-list" start="1"><li>For a given <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> and <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi></mrow><annotation encoding="application/x-tex">M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span>, expand the initial set of beams into <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mi mathvariant="normal">/</mi><mi>M</mi></mrow><annotation encoding="application/x-tex">N/M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mord">/</span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span> <em>independent</em> subtrees.</li></ol><ol type="1" id="15d1384e-bcac-8081-8508-feb06a13469b" class="numbered-list" start="2"><li>For each subtree, select the step with the highest PRM score.</li></ol><ol type="1" id="15d1384e-bcac-806a-976f-ec9596cd9532" class="numbered-list" start="3"><li>Generate <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi></mrow><annotation encoding="application/x-tex">M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span> new steps from the nodes selected in step (2) and select the step with the highest PRM score.</li></ol><ol type="1" id="15d1384e-bcac-808e-aa2b-f391ec426953" class="numbered-list" start="4"><li>Repeat step (3) until the EOS token or maximum tree depth is reached.</li></ol><p id="15d1384e-bcac-8087-b916-d9603de035dd" class="">Here’s the results from applying DVTS to Llama 1B:</p><figure id="15b1384e-bcac-801c-a1e7-d4e544826da3" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-all.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-all.png"/></a></figure><p id="15b1384e-bcac-80e1-bc9b-dbdb5738b9f1" class="">As we can see, DVTS provides a complementary strategy to beam search: at small <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> beam search is more effective at finding correct solutions, but at large <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> the diversity of DVTS candidates kicks in and we get better performance. </p><p id="15d1384e-bcac-80a7-8379-dca3c329c433" class="">We can also see this manifested in the problem difficulty breakdown, where DVTS enhances performance on the easy / medium problems at large N, while beam search is best at small N across model problem difficulties:</p><figure id="15b1384e-bcac-807a-8dca-f322077cc616" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/levels-all.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/levels-all.png"/></a></figure>
152
 
153
+ <h2 id="1591384e-bcac-806b-9dd0-c80a250c7754" class="">The best of all worlds: compute-optimal scaling</h2><p id="1591384e-bcac-80e0-93e6-ceaacc131142" class="">Armed with various search strategies, a natural question is which one is best? In the DeepMind paper, they proposed a <em><strong>compute-optimal</strong></em> <em><strong>scaling strategy</strong></em> where one selects the search method and hyperparameters <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>θ</mi></mrow><annotation encoding="application/x-tex">\theta</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span></span></span></span></span><span></span></span> that achieves the <em><strong>best performance for a given compute budget </strong></em><em><style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span></em><em><strong>:</strong></em></p><figure id="15e1384e-bcac-8054-8afa-ca441a776a05" class="equation"><style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><div class="equation-container"><span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><msubsup><mi>θ</mi><mrow><mi>q</mi><mo separator="true">,</mo><msup><mi>a</mi><mo>∗</mo></msup><mo stretchy="false">(</mo><mi>q</mi><mo stretchy="false">)</mo></mrow><mo>∗</mo></msubsup><mo stretchy="false">(</mo><mi>N</mi><mo stretchy="false">)</mo><mo>=</mo><mi><munder><mo><mi>arg</mi><mo>⁡</mo><mi>max</mi><mo>⁡</mo></mo><mi>θ</mi></munder></mi><mrow><mo fence="true">(</mo><msub><mi mathvariant="double-struck">E</mi><mrow><mi>y</mi><mo>∼</mo><mtext>Target</mtext><mo stretchy="false">(</mo><mi>θ</mi><mo separator="true">,</mo><mi>N</mi><mo separator="true">,</mo><mi>q</mi><mo stretchy="false">)</mo></mrow></msub><mrow><mo fence="true">[</mo><msub><mn mathvariant="double-struck">1</mn><mrow><mi>y</mi><mo>=</mo><msup><mi>y</mi><mo>∗</mo></msup><mo stretchy="false">(</mo><mi>q</mi><mo stretchy="false">)</mo></mrow></msub><mo fence="true">]</mo></mrow><mo fence="true">)</mo></mrow><mo separator="true">,</mo></mrow><annotation encoding="application/x-tex">\theta_{q,a^*(q)}^*(N) = \underset{\theta}{\arg\max} \left( \mathbb{E}_{y \sim \text{Target}(\theta, N, q)} \left[ \mathbb{1}_{y = y^*(q)} \right] \right),</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.197em;vertical-align:-0.447em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7387em;"><span style="top:-2.428em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mpunct mtight">,</span><span class="mord mtight"><span class="mord mathnormal mtight">a</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.6183em;"><span style="top:-2.786em;margin-right:0.0714em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mopen mtight">(</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mclose mtight">)</span></span></span></span><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.447em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.7965em;vertical-align:-0.9465em;"></span><span class="mord"><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.4306em;"><span style="top:-2.1535em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop"><span class="mop">ar<span style="margin-right:0.01389em;">g</span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop">max</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.9465em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="minner"><span class="mopen delimcenter" style="top:0em;"><span class="delimsizing size1">(</span></span><span class="mord"><span class="mord mathbb">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.5198em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="mrel mtight">∼</span><span class="mord text mtight"><span class="mord mtight">Target</span></span><span class="mopen mtight">(</span><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight" style="margin-right:0.10903em;">N</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.3552em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="minner"><span class="mopen delimcenter" style="top:0em;"><span class="delimsizing size1">[</span></span><span class="mord"><span class="mord">1</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.5198em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="mrel mtight">=</span><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.6183em;"><span style="top:-2.786em;margin-right:0.0714em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mopen mtight">(</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.3552em;"><span></span></span></span></span></span></span><span class="mclose delimcenter" style="top:0em;"><span class="delimsizing size1">]</span></span></span><span class="mclose delimcenter" style="top:0em;"><span class="delimsizing size1">)</span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mpunct">,</span></span></span></span></span></div></figure><p id="15e1384e-bcac-8011-bf10-cc868e25db7c" class="">where <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mi>y</mi><mo>∗</mo></msup><mo stretchy="false">(</mo><mi>q</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">y^*(q)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.6887em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="mclose">)</span></span></span></span></span><span></span></span> is the ground-truth for question <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>q</mi></mrow><annotation encoding="application/x-tex">q</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.1944em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">q</span></span></span></span></span><span></span></span> and <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msubsup><mi>θ</mi><mrow><mi>q</mi><mo separator="true">,</mo><msup><mi>a</mi><mo>∗</mo></msup><mo stretchy="false">(</mo><mi>q</mi><mo stretchy="false">)</mo></mrow><mo>∗</mo></msubsup><mo stretchy="false">(</mo><mi>N</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\theta_{q,a^*(q)}^*(N)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.247em;vertical-align:-0.497em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6887em;"><span style="top:-2.378em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mpunct mtight">,</span><span class="mord mtight"><span class="mord mathnormal mtight">a</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.6183em;"><span style="top:-2.786em;margin-right:0.0714em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mopen mtight">(</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mclose mtight">)</span></span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.497em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mclose">)</span></span></span></span></span><span></span></span> denotes the compute-optimal scaling strategy. Since computing <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msubsup><mi>θ</mi><mrow><mi>q</mi><mo separator="true">,</mo><msup><mi>a</mi><mo>∗</mo></msup><mo stretchy="false">(</mo><mi>q</mi><mo stretchy="false">)</mo></mrow><mo>∗</mo></msubsup><mo stretchy="false">(</mo><mi>N</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\theta_{q,a^*(q)}^*(N)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.247em;vertical-align:-0.497em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6887em;"><span style="top:-2.378em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mpunct mtight">,</span><span class="mord mtight"><span class="mord mathnormal mtight">a</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.6183em;"><span style="top:-2.786em;margin-right:0.0714em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mopen mtight">(</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mclose mtight">)</span></span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.497em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mclose">)</span></span></span></span></span><span></span></span> directly is somewhat tricky, DeepMind proposed an approximation based on the <em><strong>problem difficulty</strong></em>, i.e. allocate test-time compute according to which search strategy achieves best performance for a given difficulty level.</p><p id="15a1384e-bcac-80c9-a276-d5ea8974c543" class="">For example, on simpler problems and lower compute budgets, it is better to use strategies like Best-of-N, while on harder problems, beam search is the better choice. To implement this, for each method we compute the accuracy for a given difficulty level and test-time compute budget. And voila, we now have our compute-optimal curve!</p><figure id="15b1384e-bcac-80b3-bc58-d20ba41d3950" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-opt.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-opt.png"/></a></figure>
154
 
155
  <h2 id="1591384e-bcac-809a-96d2-e928398d159a" class="">Scaling up to larger models</h2><p id="15a1384e-bcac-8078-86d7-f48c2146444e" class="">We also explored scaling up the compute-optimal recipe to Llama 3.2 3B Instruct to see at what point the benefits of the PRM fade in comparison to the policy’s own capacity. To our surprise, compute-optimal scaling works remarkably well, with the 3B model surpassing the performance of Llama 3.1 70B Instruct (22x it's size!):</p><figure id="15b1384e-bcac-80b3-bc58-d20ba41d3950" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-opt-3b.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-opt-3b.png"/></a></figure>
156