<h1>Creating a Dutch translation app</h1> <p>Narsil, 2020-07-22</p> <blockquote> <p>TL;DR: Having recently moved to the Netherlands, and in order to avoid Google-translating everything, I did the next best thing to learning the language: I created a clone of translate.google.com.</p> </blockquote> <h2 id="find-a-correct-training-loop">Find a correct training loop</h2> <p>My first instinct was to check <a href="https://github.com/huggingface/transformers">Hugging Face</a>, as that repo contains solid implementations that I know are easy to modify. However, in this particular instance, the translation example does not start from scratch, and I wanted to see what multilingual translation could do here, as I use English, Dutch &amp; French on translate.google.com (for food, French sometimes works much better for me than English).</p> <p>My second guess was <a href="https://github.com/pytorch/fairseq">Fairseq</a> from Facebook, which has an actual example for multilingual German, French and English. Close enough for my needs. First things first: follow the example by the book, because most implementations out there are broken and won't work out of the box.</p> <p>This time, it turned out particularly smooth. Clone the repo, then follow the <a href="https://github.com/pytorch/fairseq/tree/master/examples/translation#multilingual-translation">instructions</a>:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code># First install sacrebleu and sentencepiece
pip install sacrebleu sentencepiece

# Then download and preprocess the data
cd examples/translation/
bash prepare-iwslt17-multilingual.sh
cd ../..
# Binarize the de-en dataset
TEXT=examples/translation/iwslt17.de_fr.en.bpe16k
fairseq-preprocess --source-lang de --target-lang en \
    --trainpref $TEXT/train.bpe.de-en \
    --validpref $TEXT/valid0.bpe.de-en,$TEXT/valid1.bpe.de-en,$TEXT/valid2.bpe.de-en,$TEXT/valid3.bpe.de-en,$TEXT/valid4.bpe.de-en,$TEXT/valid5.bpe.de-en \
    --destdir data-bin/iwslt17.de_fr.en.bpe16k \
    --workers 10

# Binarize the fr-en dataset
# NOTE: it's important to reuse the en dictionary from the previous step
fairseq-preprocess --source-lang fr --target-lang en \
    --trainpref $TEXT/train.bpe.fr-en \
    --validpref $TEXT/valid0.bpe.fr-en,$TEXT/valid1.bpe.fr-en,$TEXT/valid2.bpe.fr-en,$TEXT/valid3.bpe.fr-en,$TEXT/valid4.bpe.fr-en,$TEXT/valid5.bpe.fr-en \
    --tgtdict data-bin/iwslt17.de_fr.en.bpe16k/dict.en.txt \
    --destdir data-bin/iwslt17.de_fr.en.bpe16k \
    --workers 10

# Train a multilingual transformer model
# NOTE: the command below assumes 1 GPU, but accumulates gradients from
# 8 fwd/bwd passes to simulate training on 8 GPUs
mkdir -p checkpoints/multilingual_transformer
CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt17.de_fr.en.bpe16k/ \
    --max-epoch 50 \
    --ddp-backend=no_c10d \
    --task multilingual_translation --lang-pairs de-en,fr-en \
    --arch multilingual_transformer_iwslt_de_en \
    --share-decoders --share-decoder-input-output-embed \
    --optimizer adam --adam-betas '(0.9, 0.98)' \
    --lr 0.0005 --lr-scheduler inverse_sqrt --min-lr '1e-09' \
    --warmup-updates 4000 --warmup-init-lr '1e-07' \
    --label-smoothing 0.1 --criterion label_smoothed_cross_entropy \
    --dropout 0.3 --weight-decay 0.0001 \
    --save-dir checkpoints/multilingual_transformer \
    --max-tokens 4000 \
    --update-freq 8

# Generate and score the test set with sacrebleu
SRC=de
sacrebleu --test-set iwslt17 --language-pair ${SRC}-en --echo src \
| python scripts/spm_encode.py --model examples/translation/iwslt17.de_fr.en.bpe16k/sentencepiece.bpe.model \
&gt; iwslt17.test.${SRC}-en.${SRC}.bpe
cat iwslt17.test.${SRC}-en.${SRC}.bpe \
| fairseq-interactive data-bin/iwslt17.de_fr.en.bpe16k/ \
    --task multilingual_translation --lang-pairs de-en,fr-en \
    --source-lang ${SRC} --target-lang en \
    --path checkpoints/multilingual_transformer/checkpoint_best.pt \
    --buffer-size 2000 --batch-size 128 \
    --beam 5 --remove-bpe=sentencepiece \
&gt; iwslt17.test.${SRC}-en.en.sys
</code></pre></div></div> <h2 id="the-data">The data</h2> <p>While it's training, let's look at where I can get Dutch data. IWSLT 2017 did not seem to have Dutch data <a href="https://wit3.fbk.eu/mt.php?release=2017-01-trnted">at first glance</a> or <a href="https://wit3.fbk.eu/mt.php?release=2017-01-trnmted">here</a>. I also tried mimicking the address used by Facebook's <code class="language-plaintext highlighter-rouge">prepare-iwslt17-multilingual.sh</code> (the German data comes from <code class="language-plaintext highlighter-rouge">https://wit3.fbk.eu/archive/2017-01-trnted/texts/de/en/de-en.tgz</code>, so I simply tried <code class="language-plaintext highlighter-rouge">https://wit3.fbk.eu/archive/2017-01-trnted/texts/nl/en/nl-en.tgz</code>). Turns out there is no such archive. <a href="https://www.statmt.org/europarl/">Europarl</a> seemed like a good bet, but looking at the data, the language is quite formal and not very dialogue-like, which might explain why it does not seem to be used that often. Looking back at IWSLT 2017, I finally found the <a href="https://wit3.fbk.eu/mt.php?release=2017-01-mted-test">Dutch data</a> and the <a href="https://wit3.fbk.eu/mt.php?release=2017-01-trnmted">training data</a>. Is it me, or are competition websites really hard to read?</p> <h2 id="the-actual-training-loop">The actual training loop</h2> <p>OK, let's reuse the training loop from the German example: we just need to copy the Dutch files into the same layout as the German ones and edit all the scripts and command lines accordingly. 
I also had to multiply the test files: somehow Facebook has tst2011, tst2012, tst2013, tst2014 and tst2015 for the German data, which do not seem to exist on the competition website… So instead of trying to figure out where that information came from, I simply copy-pasted the tst2010 file into dummy versions for tst2011…tst2015 (simply omitting them makes the bash scripts fail, because file alignment is a requirement, and I didn't want to spend more than 5 minutes editing a bash script).</p> <p>Now we can run our edited bash script:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>cd examples/translation/
bash prepare-iwslt17-multilingual_nl.sh
cd ../..
</code></pre></div></div> <p>Preprocess the Dutch data:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>TEXT=examples/translation/iwslt17.nl_fr.en.bpe16k
fairseq-preprocess --source-lang nl --target-lang en \
    --trainpref $TEXT/train.bpe.nl-en \
    --validpref $TEXT/valid0.bpe.nl-en,$TEXT/valid1.bpe.nl-en,$TEXT/valid2.bpe.nl-en,$TEXT/valid3.bpe.nl-en,$TEXT/valid4.bpe.nl-en,$TEXT/valid5.bpe.nl-en \
    --destdir data-bin/iwslt17.nl_fr.en.bpe16k \
    --workers 10
</code></pre></div></div> <p>Now let's preprocess the French data:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code># NOTE: it's important to reuse the en dictionary from the previous step
fairseq-preprocess --source-lang fr --target-lang en \
    --trainpref $TEXT/train.bpe.fr-en \
    --validpref $TEXT/valid0.bpe.fr-en,$TEXT/valid1.bpe.fr-en,$TEXT/valid2.bpe.fr-en,$TEXT/valid3.bpe.fr-en,$TEXT/valid4.bpe.fr-en,$TEXT/valid5.bpe.fr-en \
    --tgtdict data-bin/iwslt17.nl_fr.en.bpe16k/dict.en.txt \
    --destdir data-bin/iwslt17.nl_fr.en.bpe16k \
    --workers 10
</code></pre></div></div> <p>Overall a pretty simple task, just a bit bothersome to hit all the various walls.</p> <p>Now that we have preprocessed the Dutch data, we can run the training loop 
on our own data!</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code># NOTE: don't change the arch (the *_iwslt_de_en name is fine for nl too),
# but do change --lang-pairs and the checkpoint directory (--save-dir)
mkdir -p checkpoints/multilingual_transformer_nl
CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt17.nl_fr.en.bpe16k/ \
    --max-epoch 50 \
    --ddp-backend=no_c10d \
    --task multilingual_translation --lang-pairs nl-en,fr-en \
    --arch multilingual_transformer_iwslt_de_en \
    --share-decoders --share-decoder-input-output-embed \
    --optimizer adam --adam-betas '(0.9, 0.98)' \
    --lr 0.0005 --lr-scheduler inverse_sqrt --min-lr '1e-09' \
    --warmup-updates 4000 --warmup-init-lr '1e-07' \
    --label-smoothing 0.1 --criterion label_smoothed_cross_entropy \
    --dropout 0.3 --weight-decay 0.0001 \
    --save-dir checkpoints/multilingual_transformer_nl \
    --max-tokens 4000 \
    --update-freq 8
</code></pre></div></div> <h2 id="checking-the-final-result">Checking the final result</h2> <p>So now we have a model at <code class="language-plaintext highlighter-rouge">checkpoints/multilingual_transformer_nl/checkpoint_best.pt</code>, let's run it!</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Generate and score the test set with sacrebleu
SRC=nl
sacrebleu --test-set iwslt17 --language-pair ${SRC}-en --echo src \
| python scripts/spm_encode.py --model examples/translation/iwslt17.nl_fr.en.bpe16k/sentencepiece.bpe.model \
&gt; iwslt17.test.${SRC}-en.${SRC}.bpe
cat iwslt17.test.${SRC}-en.${SRC}.bpe \
| fairseq-interactive data-bin/iwslt17.nl_fr.en.bpe16k/ \
    --task multilingual_translation --lang-pairs nl-en,fr-en \
    --source-lang ${SRC} --target-lang en \
    --path checkpoints/multilingual_transformer_nl/checkpoint_best.pt \
    --buffer-size 2000 --batch-size 128 \
    --beam 5 --remove-bpe=sentencepiece \
&gt; iwslt17.test.${SRC}-en.en.sys
</code></pre></div></div> <p>But whoops… <code class="language-plaintext highlighter-rouge">sacreBLEU: No such language pair "nl-en" sacreBLEU: 
Available language pairs for test set "iwslt17": en-fr, fr-en, en-de, de-en, en-zh, zh-en, en-ar, ar-en, en-ja, ja-en, en-ko, ko-en</code></p> <p>So it looks like we're going to need to pipe some of our own data into this pipeline; we can just use the validation set we trained with:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>cat examples/translation/iwslt17.nl_fr.en.bpe16k/valid0.bpe.nl-en.nl \
| python scripts/spm_encode.py --model examples/translation/iwslt17.nl_fr.en.bpe16k/sentencepiece.bpe.model \
&gt; iwslt17.test.${SRC}-en.${SRC}.bpe
</code></pre></div></div> <p>There we go: we have encoded our validation dataset with our multilingual BPE tokenizer. We can now run our translation command:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>cat iwslt17.test.${SRC}-en.${SRC}.bpe \
| fairseq-interactive data-bin/iwslt17.nl_fr.en.bpe16k/ \
    --task multilingual_translation --lang-pairs nl-en,fr-en \
    --source-lang ${SRC} --target-lang en \
    --path checkpoints/multilingual_transformer_nl/checkpoint_best.pt \
    --buffer-size 2000 --batch-size 128 \
    --beam 5 --remove-bpe=sentencepiece
</code></pre></div></div> <p>Here are some outputs (not cherry-picked):</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>S-999   Iedereen heeft een vissenkom nodig.
H-999   -1.0272072553634644     Everybody needs a fishing ticket.
D-999   -1.0272072553634644     Everybody needs a fishing ticket.
P-999   -1.5687 -0.2169 -0.2363 -2.0637 -2.6527 -0.2981 -0.1540
</code></pre></div></div> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>S-998   Het leidt tot meer verlamming en minder tevredenheid.
H-998   -0.32848915457725525    It leads to more paralysis and less satisfaction.
D-998   -0.32848915457725525    It leads to more paralysis and less satisfaction.
P-998   -0.9783 -0.3836 -0.1854 -0.8328 -0.1779 -0.0163 -0.3334 -0.3619 -0.2152 -0.0450 -0.2831 -0.1289
</code></pre></div></div> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>S-987   Ze maken ons leven minder waard.
H-987   -0.5473383665084839     They make our lives worth less.
D-987   -0.5473383665084839     They make our lives worth less.
</code></pre></div></div> <p>Seems good enough for now.</p> <h2 id="productizing">Productizing</h2> <h3 id="flask-server">Flask server</h3> <p>OK, in order to productionize, I initially wanted to move away from fairseq, but a lot of logic is actually tied to fairseq-interactive (beam search, loading all the args, ensembling the model, source-language selection and so on). It's definitely possible to move out of it, but it felt like a few days' job, much more than I was willing to invest in this particular approach.</p> <p>So the idea is to have a Flask server sitting in front of the model: call the appropriate encoding with spm_encode, pass it to fairseq-interactive, and send the D-XXX line back to the caller.</p> <p>We're going to containerize it and deploy to Kubernetes (it just so happens I have a Kubernetes cluster running, so fewer problems deploying on it). I considered using ONNX.js (or TFLite) to deploy directly in the browser, which saves a lot of headaches on deployment and on keeping the service running in the long run (like I did for the <a href="https://narsil.github.io/assets/face/">glasses</a> project). Here the main problem is the size of the model (600 MB). 
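</p>

<p>To make that round-trip concrete, here is a rough, hypothetical sketch of the spm_encode / fairseq-interactive / D-line pipeline, shelling out to the CLI. The helper names (<code class="language-plaintext highlighter-rouge">translate_once</code>, <code class="language-plaintext highlighter-rouge">extract_translation</code>) are illustrative, not actual project code; paths mirror the commands used earlier:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import subprocess


def extract_translation(stdout):
    # fairseq-interactive prints S-/H-/D-/P- prefixed lines per sentence
    # (see the sample outputs above); the D- line holds the detokenized
    # hypothesis, tab-separated: id, score, translation.
    for line in stdout.splitlines():
        if line.startswith("D-"):
            return line.split("\t", 2)[2]
    return ""


def translate_once(text):
    # Encode with the multilingual sentencepiece model...
    bpe = subprocess.run(
        ["python", "scripts/spm_encode.py", "--model",
         "examples/translation/iwslt17.nl_fr.en.bpe16k/sentencepiece.bpe.model"],
        input=text, capture_output=True, text=True, check=True).stdout
    # ...generate with the trained checkpoint...
    out = subprocess.run(
        ["fairseq-interactive", "data-bin/iwslt17.nl_fr.en.bpe16k/",
         "--task", "multilingual_translation", "--lang-pairs", "nl-en,fr-en",
         "--source-lang", "nl", "--target-lang", "en",
         "--path", "checkpoints/multilingual_transformer_nl/checkpoint_best.pt",
         "--beam", "5", "--remove-bpe=sentencepiece"],
        input=bpe, capture_output=True, text=True, check=True).stdout
    # ...and keep only the translation from the D- line.
    return extract_translation(out)
</code></pre></div></div>

<p>Paying the CLI start-up cost (and the model load) on every request like this would be far too slow, which is why the actual server loads everything once up front instead.</p>

<p>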
I could go back and try to optimize, but that's a pretty big model; it's going to be hard to bring it down to a comfortable level for browser-only mode (again, just too much work for what I have in mind here).</p> <p>So let's get started from Flask's <a href="https://flask.palletsprojects.com/en/1.1.x/quickstart/">hello world</a>:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>from flask import Flask

app = Flask(__name__)


@app.route('/')
def hello_world():
    return 'Hello, World!'
</code></pre></div></div> <p>Let's edit it a bit to include our translate function.</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import json

from flask import Flask, request

app = Flask(__name__)


def translate(text):
    # TODO later
    return "This is a translation !"


@app.route('/', methods=["POST"])
def hello():
    text = request.form["input"]
    print(f"IN {text}")
    output = translate(text)
    print(f"OUT {output}")
    result = json.dumps({"en": output})
    return result
</code></pre></div></div> <p>We can run our example and check that it's running with curl:</p> <div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>$ curl -d input="Ik heft een appel." http://localhost:5000/
{"en": "This is a translation !"}
</code></pre></div></div> <h4 id="implementing-the-translate-function">Implementing the translate function.</h4> <p>OK, this is where we are super tied to the fairseq-interactive code: I had to dig into the source code, copy most of it, and mainly split the <code class="language-plaintext highlighter-rouge">Model loading</code> code 
from the <code class="language-plaintext highlighter-rouge">Model running</code> code. For that I used a lot of globals, as the original code does not separate these two concerns (tidying this up will be a later goal if it ever comes to that).</p> <p>The final implementation is quite verbose, but it is available <a href="https://github.com/Narsil/translate/blob/master/server/translate.py">here</a>.</p> <p>One good point about this implementation is that we load the model early, so it's available right away once the server is up (though the server does take some time to come up). A negative point is that, because the model is loaded eagerly, forking becomes a nightmare, which basically prevents us from using WSGI efficiently, the <a href="https://flask.palletsprojects.com/en/1.1.x/deploying/">recommended way of deploying Flask</a>. It's fine for now; this is a personal project after all. For a more stable deployment I would try to remove Python from the web part of the equation if possible: it's really slow and hard to work with on web servers because of the forking/threading story in Python.</p> <p>So now our backend can really translate!</p> <div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>$ curl -d input="Ik heft een appel." http://localhost:5000/
{"en": "I have an apple."}
</code></pre></div></div> <p>Before moving that to the cloud, let's build a nice interface in front of it.</p> <h3 id="react-front">React front</h3> <p>OK, so we're going to use React with TypeScript. React, because we're going to need JS anyway to fetch translations without clicking a button as in a plain HTML form. 
It’s also more convenient to use Material-UI which I find helps make a website nice from scratch (and I’m tired of Bootstrap). Typescript because it’s just saner than VanillaJS (it won’t make much of a difference here.</p> <div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>yarn create react-app app <span class="nt">--template</span> typescript <span class="nb">cd </span>app yarn add @material-ui/core </code></pre></div></div> <p>Let’s edit our App.tsx to use Material-UI and get the initial layout looking like <a href="translate.google.com">translate.google.com</a>.</p> <div class="language-typescript highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">import</span> <span class="nx">React</span> <span class="k">from</span> <span class="dl">"</span><span class="s2">react</span><span class="dl">"</span><span class="p">;</span> <span class="k">import</span> <span class="p">{</span> <span class="nx">makeStyles</span> <span class="p">}</span> <span class="k">from</span> <span class="dl">"</span><span class="s2">@material-ui/core/styles</span><span class="dl">"</span><span class="p">;</span> <span class="k">import</span> <span class="nx">TextField</span> <span class="k">from</span> <span class="dl">"</span><span class="s2">@material-ui/core/TextField</span><span class="dl">"</span><span class="p">;</span> <span class="k">import</span> <span class="nx">Card</span> <span class="k">from</span> <span class="dl">"</span><span class="s2">@material-ui/core/Card</span><span class="dl">"</span><span class="p">;</span> <span class="k">import</span> <span class="nx">Grid</span> <span class="k">from</span> <span class="dl">"</span><span class="s2">@material-ui/core/Grid</span><span class="dl">"</span><span class="p">;</span> <span class="kd">const</span> <span class="nx">useStyles</span> <span class="o">=</span> <span class="nx">makeStyles</span><span class="p">(</span><span class="nx">theme</span> <span 
class="o">=&gt;</span> <span class="p">({</span> <span class="na">app</span><span class="p">:</span> <span class="p">{</span> <span class="na">display</span><span class="p">:</span> <span class="dl">"</span><span class="s2">flex</span><span class="dl">"</span><span class="p">,</span> <span class="na">justifyContent</span><span class="p">:</span> <span class="dl">"</span><span class="s2">center</span><span class="dl">"</span><span class="p">,</span> <span class="na">alignItems</span><span class="p">:</span> <span class="dl">"</span><span class="s2">center</span><span class="dl">"</span><span class="p">,</span> <span class="na">height</span><span class="p">:</span> <span class="dl">"</span><span class="s2">100vh</span><span class="dl">"</span> <span class="p">}</span> <span class="p">}));</span> <span class="kd">function</span> <span class="nx">App</span><span class="p">()</span> <span class="p">{</span> <span class="kd">const</span> <span class="nx">classes</span> <span class="o">=</span> <span class="nx">useStyles</span><span class="p">();</span> <span class="k">return</span> <span class="p">(</span> <span class="o">&lt;</span><span class="nx">div</span> <span class="nx">className</span><span class="o">=</span><span class="p">{</span><span class="nx">classes</span><span class="p">.</span><span class="nx">app</span><span class="p">}</span><span class="o">&gt;</span> <span class="o">&lt;</span><span class="nx">Card</span><span class="o">&gt;</span> <span class="o">&lt;</span><span class="nx">form</span><span class="o">&gt;</span> <span class="o">&lt;</span><span class="nx">Grid</span> <span class="nx">container</span><span class="o">&gt;</span> <span class="o">&lt;</span><span class="nx">Grid</span> <span class="nx">item</span> <span class="nx">xs</span><span class="o">=</span><span class="p">{</span><span class="mi">12</span><span class="p">}</span> <span class="nx">md</span><span class="o">=</span><span class="p">{</span><span class="mi">6</span><span 
class="p">}</span><span class="o">&gt;</span> <span class="o">&lt;</span><span class="nx">TextField</span> <span class="nx">id</span><span class="o">=</span><span class="dl">"</span><span class="s2">standard-basic</span><span class="dl">"</span> <span class="nx">label</span><span class="o">=</span><span class="dl">"</span><span class="s2">Dutch</span><span class="dl">"</span> <span class="nx">multiline</span> <span class="nx">autoFocus</span> <span class="o">/&gt;</span> <span class="o">&lt;</span><span class="sr">/Grid</span><span class="err">&gt; </span> <span class="o">&lt;</span><span class="nx">Grid</span> <span class="nx">item</span> <span class="nx">xs</span><span class="o">=</span><span class="p">{</span><span class="mi">12</span><span class="p">}</span> <span class="nx">md</span><span class="o">=</span><span class="p">{</span><span class="mi">6</span><span class="p">}</span><span class="o">&gt;</span> <span class="o">&lt;</span><span class="nx">TextField</span> <span class="nx">id</span><span class="o">=</span><span class="dl">"</span><span class="s2">standard-basic</span><span class="dl">"</span> <span class="nx">label</span><span class="o">=</span><span class="dl">"</span><span class="s2">English</span><span class="dl">"</span> <span class="nx">multiline</span> <span class="o">/&gt;</span> <span class="o">&lt;</span><span class="sr">/Grid</span><span class="err">&gt; </span> <span class="o">&lt;</span><span class="sr">/Grid</span><span class="err">&gt; </span> <span class="o">&lt;</span><span class="sr">/form</span><span class="err">&gt; </span> <span class="o">&lt;</span><span class="sr">/Card</span><span class="err">&gt; </span> <span class="o">&lt;</span><span class="sr">/div</span><span class="err">&gt; </span> <span class="p">);</span> <span class="p">}</span> <span class="k">export</span> <span class="k">default</span> <span class="nx">App</span><span class="p">;</span> </code></pre></div></div> <p>Here is the result : <img 
src="https://i.imgur.com/ZszCVQU.png" alt="" /></p> <p>Now let’s look at the logic (simplified):</p> <div class="language-typescript highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nx">type</span> <span class="nx">Timeout</span> <span class="o">=</span> <span class="nx">ReturnType</span><span class="o">&lt;</span><span class="k">typeof</span> <span class="nx">setTimeout</span><span class="o">&gt;</span><span class="p">;</span> <span class="kd">const</span> <span class="p">[</span><span class="nx">text</span><span class="p">,</span> <span class="nx">setText</span><span class="p">]</span> <span class="o">=</span> <span class="nx">useState</span><span class="p">(</span><span class="dl">""</span><span class="p">);</span> <span class="kd">const</span> <span class="p">[</span><span class="nx">time</span><span class="p">,</span> <span class="nx">setTime</span><span class="p">]</span> <span class="o">=</span> <span class="nx">useState</span><span class="o">&lt;</span><span class="nx">Timeout</span> <span class="o">|</span> <span class="kc">null</span><span class="o">&gt;</span><span class="p">(</span><span class="kc">null</span><span class="p">);</span> <span class="kd">const</span> <span class="nx">url</span> <span class="o">=</span> <span class="dl">"</span><span class="s2">http://localhost:5000</span><span class="dl">"</span><span class="p">;</span> <span class="kd">const</span> <span class="nx">translate</span> <span class="o">=</span> <span class="p">(</span><span class="nx">text</span><span class="p">:</span> <span class="nx">string</span><span class="p">)</span> <span class="o">=&gt;</span> <span class="p">{</span> <span class="k">if</span> <span class="p">(</span><span class="nx">text</span> <span class="o">===</span> <span class="dl">""</span><span class="p">)</span> <span class="p">{</span> <span class="nx">setText</span><span class="p">(</span><span class="dl">""</span><span class="p">);</span> <span 
class="k">return</span><span class="p">;</span> <span class="p">}</span> <span class="kd">const</span> <span class="nx">form</span> <span class="o">=</span> <span class="k">new</span> <span class="nx">FormData</span><span class="p">();</span> <span class="nx">form</span><span class="p">.</span><span class="nx">append</span><span class="p">(</span><span class="dl">"</span><span class="s2">input</span><span class="dl">"</span><span class="p">,</span> <span class="nx">text</span><span class="p">);</span> <span class="nx">fetch</span><span class="p">(</span><span class="nx">url</span><span class="p">,</span> <span class="p">{</span> <span class="na">method</span><span class="p">:</span> <span class="dl">"</span><span class="s2">POST</span><span class="dl">"</span><span class="p">,</span> <span class="na">body</span><span class="p">:</span> <span class="nx">form</span> <span class="p">}).</span><span class="nx">then</span><span class="p">(</span><span class="nx">response</span> <span class="o">=&gt;</span> <span class="p">{</span> <span class="nx">response</span><span class="p">.</span><span class="nx">json</span><span class="p">().</span><span class="nx">then</span><span class="p">(</span><span class="nx">json</span> <span class="o">=&gt;</span> <span class="p">{</span> <span class="nx">console</span><span class="p">.</span><span class="nx">log</span><span class="p">(</span><span class="nx">json</span><span class="p">);</span> <span class="nx">setText</span><span class="p">(</span><span class="nx">json</span><span class="p">[</span><span class="dl">"</span><span class="s2">en</span><span class="dl">"</span><span class="p">]);</span> <span class="p">});</span> <span class="p">});</span> <span class="p">};</span> </code></pre></div></div> <p>Then call it on the <code class="language-plaintext highlighter-rouge">onChange</code> attribute of our Dutch field.</p> <div class="language-typescript highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span 
class="nx">onChange</span><span class="o">=</span><span class="p">{</span><span class="nx">event</span> <span class="o">=&gt;</span> <span class="p">{</span> <span class="c1">// We use a timeout handler to prevent very fast keystrokes</span> <span class="c1">// from spamming our server.</span> <span class="k">if</span> <span class="p">(</span><span class="nx">time</span> <span class="o">!==</span> <span class="kc">null</span><span class="p">)</span> <span class="p">{</span> <span class="nx">clearTimeout</span><span class="p">(</span><span class="nx">time</span><span class="p">);</span> <span class="p">}</span> <span class="kd">const</span> <span class="nx">text</span> <span class="o">=</span> <span class="nx">event</span><span class="p">.</span><span class="nx">target</span><span class="p">.</span><span class="nx">value</span><span class="p">;</span> <span class="kd">const</span> <span class="nx">timeout</span> <span class="o">=</span> <span class="nx">setTimeout</span><span class="p">(()</span> <span class="o">=&gt;</span> <span class="p">{</span> <span class="nx">translate</span><span class="p">(</span><span class="nx">text</span><span class="p">);</span> <span class="p">},</span> <span class="mi">500</span><span class="p">);</span> <span class="nx">setTime</span><span class="p">(</span><span class="nx">timeout</span><span class="p">);</span> <span class="p">}}</span> </code></pre></div></div> <p>There we have it:</p> <p><img src="https://i.imgur.com/EYZ0EWR.gif" alt="" /></p> <h3 id="lets-dockerize-">Let’s dockerize!</h3> <p>As I mentioned, loading the whole model in the flask app hinders the wsgi process forking quite a lot.
I did try it and tried to come up with easy fixes, but ultimately found that keeping the development server was just easier.</p> <p>Ok, so we’re going to need a Python docker image, and to install pytorch, fairseq, and flask into our image (we actually need flask_cors too, to make sure we can call it from any website, as it’s an API).</p> <p>As it turns out, fairseq 0.9 had a bug in the training loop and I was using master from a few months ago, and I needed to work with that specific version since there had been breaking changes in master since then. That gives us the following <code class="language-plaintext highlighter-rouge">requirements.txt</code>:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>torch
flask
flask_cors
-e git://github.com/pytorch/fairseq.git@7a6519f84fed06947bbf161c7b66c9099bc4ce53#egg=fairseq
sentencepiece
</code></pre></div></div> <p>Now our Dockerfile is going to install the python dependencies, copy all the local files (including the model and tokenizer files) and run the flask server. That gives us:</p> <pre><code class="language-Dockerfile">FROM python:3.7-slim
RUN pip install -U pip
RUN apt-get update &amp;&amp; apt-get install -y git build-essential # Required for building fairseq from source.
COPY server/requirements.txt /app/requirements.txt
RUN pip install -r /app/requirements.txt
COPY . /app
WORKDIR /app
CMD ["python", "translate.py"]
</code></pre> <p>Let’s build and check that it works:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>docker build -t translate:latest .
docker run -p 5000:5000 translate:latest
# Now check with curl that we can still hit the docker container and get a correct answer
curl -d input="Ik heb een appel." http://localhost:5000/
# {"en": "This is a translation !"}
</code></pre></div></div> <h3 id="kubernetes-cluster">Kubernetes cluster</h3> <p>Okay, the following part will be pretty specific to my setup.
I use a kubernetes cluster on GCP with ingress. I’m going to skip updating the SSL certificate.</p> <p>Let’s start with pushing the image to GCP:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>docker tag translate:latest gcr.io/myproject-XXXXXX/translate:1.0
docker push gcr.io/myproject-XXXXXX/translate:1.0
kubectl apply -f deployment.yaml
kubectl apply -f service.yaml
kubectl apply -f ingress.yaml
</code></pre></div></div> <p>Here are the (edited for brevity &amp; security) service files I used:</p> <div class="language-yaml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">#deployment.yaml</span>
<span class="na">apiVersion</span><span class="pi">:</span> <span class="s">apps/v1</span>
<span class="na">kind</span><span class="pi">:</span> <span class="s">Deployment</span>
<span class="na">metadata</span><span class="pi">:</span>
  <span class="na">name</span><span class="pi">:</span> <span class="s">translate-deployment</span>
<span class="na">spec</span><span class="pi">:</span>
  <span class="na">replicas</span><span class="pi">:</span> <span class="m">1</span>
  <span class="na">selector</span><span class="pi">:</span>
    <span class="na">matchLabels</span><span class="pi">:</span>
      <span class="na">app</span><span class="pi">:</span> <span class="s">translate</span>
  <span class="na">template</span><span class="pi">:</span>
    <span class="na">metadata</span><span class="pi">:</span>
      <span class="na">labels</span><span class="pi">:</span>
        <span class="na">app</span><span class="pi">:</span> <span class="s">translate</span>
    <span class="na">spec</span><span class="pi">:</span>
      <span class="na">containers</span><span class="pi">:</span>
      <span class="pi">-</span> <span class="na">name</span><span class="pi">:</span> <span class="s">translate</span>
        <span class="na">image</span><span class="pi">:</span> <span class="s">gcr.io/myproject-XXXXX/translate:1.0</span>
        <span
class="na">ports</span><span class="pi">:</span>
        <span class="pi">-</span> <span class="na">containerPort</span><span class="pi">:</span> <span class="m">5000</span>
</code></pre></div></div> <div class="language-yaml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># service.yaml</span>
<span class="na">apiVersion</span><span class="pi">:</span> <span class="s">v1</span>
<span class="na">kind</span><span class="pi">:</span> <span class="s">Service</span>
<span class="na">metadata</span><span class="pi">:</span>
  <span class="na">name</span><span class="pi">:</span> <span class="s">translate-service</span>
<span class="na">spec</span><span class="pi">:</span>
  <span class="na">type</span><span class="pi">:</span> <span class="s">NodePort</span>
  <span class="na">selector</span><span class="pi">:</span>
    <span class="na">app</span><span class="pi">:</span> <span class="s">translate</span>
  <span class="na">ports</span><span class="pi">:</span>
  <span class="pi">-</span> <span class="na">protocol</span><span class="pi">:</span> <span class="s">TCP</span>
    <span class="na">port</span><span class="pi">:</span> <span class="m">80</span>
    <span class="na">targetPort</span><span class="pi">:</span> <span class="m">5000</span>
</code></pre></div></div> <div class="language-yaml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">#ingress.yaml</span>
<span class="na">apiVersion</span><span class="pi">:</span> <span class="s">networking.k8s.io/v1beta1</span>
<span class="na">kind</span><span class="pi">:</span> <span class="s">Ingress</span>
<span class="na">metadata</span><span class="pi">:</span>
  <span class="na">name</span><span class="pi">:</span> <span class="s">ingress-front</span>
  <span class="na">annotations</span><span class="pi">:</span>
    <span class="s">kubernetes.io/ingress.global-static-ip-name</span><span class="pi">:</span> <span class="s">address-cluster</span>
    <span
class="s">networking.gke.io/managed-certificates</span><span class="pi">:</span> <span class="s">ottomate-certificate-new</span>
<span class="na">spec</span><span class="pi">:</span>
  <span class="na">rules</span><span class="pi">:</span>
  <span class="pi">-</span> <span class="na">host</span><span class="pi">:</span> <span class="s">translate.ottomate.app</span>
    <span class="na">http</span><span class="pi">:</span>
      <span class="na">paths</span><span class="pi">:</span>
      <span class="pi">-</span> <span class="na">path</span><span class="pi">:</span> <span class="s">/*</span>
        <span class="na">backend</span><span class="pi">:</span>
          <span class="na">serviceName</span><span class="pi">:</span> <span class="s">translate-service</span>
          <span class="na">servicePort</span><span class="pi">:</span> <span class="m">80</span>
</code></pre></div></div> <p>Hopefully within a few minutes you have your pod running and you can hit your own live server through the API.</p> <p>You just need to update your React app to point to the correct URL and boom, you’re done: your very own translate server app.</p> <h4 id="what-couldshould-be-done-next">What could/should be done next.</h4> <p>For the model:</p> <ul> <li>Add more data to the original training set; some words are missing, and translation can become funky on some real-world sentences I give the machine (Dutch companies tend to send very verbose emails)</li> <li>Add some data augmentation to the pool, as the current translation is very brittle to errors: the SentencePiece algorithm with sampling instead of BPE, a typo generator, or word inversions, to name a few.
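As an illustration, such a typo generator could be sketched like this (the specific noise operations and rates here are my own assumptions, not something from this post):

```python
import random

def add_typos(sentence: str, p: float = 0.05, seed: int = 0) -> str:
    # Randomly drops or duplicates characters to simulate user typos,
    # so the model sees noisy inputs at training time.
    rng = random.Random(seed)
    out = []
    for ch in sentence:
        r = rng.random()
        if r < p:            # drop the character
            continue
        elif r < 2 * p:      # duplicate the character
            out.append(ch)
            out.append(ch)
        else:                # keep the character unchanged
            out.append(ch)
    # Occasionally swap two adjacent words ("word inversions")
    words = "".join(out).split()
    if len(words) > 1 and rng.random() < 0.5:
        i = rng.randrange(len(words) - 1)
        words[i], words[i + 1] = words[i + 1], words[i]
    return " ".join(words)
```

Applied on the fly to the source side of each training pair, this exposes the model to the kind of noisy input it sees in production.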
Training some error detection algorithm on top, or using ready-made ones, could help (translate.google.com seems to apply some spell-fixing magic beforehand).</li> <li>Making the model smaller to make it portable to tflite and mobile phones for offline mode, and so on (it’s a pretty big workload to make it work, though)</li> </ul> <p>For the backend:</p> <ul> <li>Battle-testing the backend should be the first thing to do, to check failure modes and guard against naive DOS attacks.</li> <li>Something like <a href="https://github.com/pytorch/serve">TorchServe</a> seems like what we want for the model part. I’ve never used it so far, but it seems to solve some of the problems encountered here and would make iterating on various models faster (also swapping out models).</li> <li>At the other end of the spectrum, I could go for tighter control. Removing the fairseq-interactive clutter would be my first move. If I can go pytorch barebones, then using Rust with Hugging Face’s <a href="https://github.com/huggingface/tokenizers">tokenizers</a> library would probably make inference faster and deployment easier. It would of course make iteration much slower, so I would do that only when the model is very stable. It could also make offline mobile possible (with very large app data, but doable).</li> </ul> <p>For the frontend:</p> <ul> <li>Working a bit more on the mobile part of the design, which is a bit broken at the moment.</li> <li>Maybe add buttons to switch languages easily, or switch language sides (although I mostly use Dutch-&gt;English and Dutch-&gt;French)</li> <li>Add a react-native app so that I can translate from my phone.
(Without offline mode)</li> </ul>nicolasTL;DR Recently moved to the Netherlands, in order to avoid Googling translate everything, I did the next best thing to learning the language: I created a clone of translate.google.comSuper simple estimation of available solar energy2020-03-19T00:00:00+01:002020-03-19T00:00:00+01:00http://localhost:4000/narsil.github.io/energy/2020/03/19/solar-energy<!-- ################################################# ### THIS FILE WAS AUTOGENERATED! DO NOT EDIT! ### ################################################# # file to edit: _notebooks/2020-03-19-solar-energy.ipynb --> <div class="container" id="notebook-container"> <div class="cell border-box-sizing code_cell rendered"> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h1 id="Solar-energy">Solar energy<a class="anchor-link" href="#Solar-energy"> </a></h1><h2 id="Stefan-boltzmann's-law">Stefan boltzmann's law<a class="anchor-link" href="#Stefan-boltzmann's-law"> </a></h2><p>$ \text{Surface energy} = \sigma T^4$</p> <p>For the sun, $T = \text{5,778 }K$</p> <p>$\sigma = 5.67 \times 10 ^{-8} W.m^{-2}.K^{-4}$</p> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="kn">from</span> <span class="nn">sympy.physics.units</span> <span class="kn">import</span> <span class="n">K</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">giga</span> <span class="n">sigma</span> <span class="o">=</span> <span class="mf">5.67</span> <span class="o">*</span> <span class="mi">10</span><span class="o">**</span><span class="p">(</span><span class="o">-</span><span class="mi">8</span><span class="p">)</span> <span class="o">*</span> <span class="n">W</span> 
<span class="o">*</span><span class="n">m</span><span class="o">**</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">)</span> <span class="o">*</span> <span class="n">K</span><span class="o">**</span><span class="p">(</span><span class="o">-</span><span class="mi">4</span><span class="p">)</span> <span class="n">T</span> <span class="o">=</span> <span class="mi">5778</span> <span class="o">*</span> <span class="n">K</span> <span class="n">surface_energy</span> <span class="o">=</span> <span class="n">sigma</span> <span class="o">*</span> <span class="n">T</span><span class="o">**</span><span class="mi">4</span> <span class="nb">print</span><span class="p">(</span><span class="n">surface_energy</span><span class="p">)</span> </pre></div> </div> </div> </div> <div class="output_wrapper"> <div class="output"> <div class="output_area"> <div class="output_subarea output_stream output_stdout output_text"> <pre>63196526.5460292*watt/meter**2 </pre> </div> </div> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h2 id="Total-emitted-solar-energy">Total emitted solar energy<a class="anchor-link" href="#Total-emitted-solar-energy"> </a></h2><p>$ Radiation = \text{Surface of the sun} \times \text{Surface energy} $</p> <p>$ Radiation = 4 \pi r^2 \times \text{Surface energy} $</p> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="kn">from</span> <span class="nn">sympy</span> <span class="kn">import</span> <span class="o">*</span> <span class="n">r_sun</span> <span class="o">=</span> <span class="mi">696_340</span> <span class="o">*</span> <span class="mi">1000</span> <span class="o">*</span><span class="n">m</span> <span class="n">surface_of_sun</span> <span 
class="o">=</span> <span class="mi">4</span> <span class="o">*</span> <span class="n">pi</span> <span class="o">*</span> <span class="n">r_sun</span> <span class="o">**</span> <span class="mi">2</span> <span class="n">radiation</span> <span class="o">=</span> <span class="n">surface_of_sun</span> <span class="o">*</span> <span class="n">surface_energy</span> <span class="nb">print</span><span class="p">(</span><span class="n">radiation</span><span class="p">)</span> </pre></div> </div> </div> </div> <div class="output_wrapper"> <div class="output"> <div class="output_area"> <div class="output_subarea output_stream output_stdout output_text"> <pre>1.22573302243694e+26*pi*watt </pre> </div> </div> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h2 id="Energy-received-at-earth-average-distance">Energy received at earth average distance<a class="anchor-link" href="#Energy-received-at-earth-average-distance"> </a></h2><p>$ \text{Radiation received} = \frac{\text{Total sun radiation}}{ \text{sphere at earth's distance}}$</p> <p>$ \text{Radiation received} = \frac{Radiation}{ 4 \pi D_{earth-sun}^2} $</p> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="n">R_earth</span> <span class="o">=</span> <span class="mi">6_371</span> <span class="o">*</span> <span class="mi">1000</span> <span class="o">*</span> <span class="n">m</span> <span class="n">D_earth_sun</span> <span class="o">=</span> <span class="mf">148.88</span> <span class="o">*</span> <span class="mi">10</span><span class="o">**</span><span class="mi">6</span> <span class="o">*</span> <span class="mi">1000</span> <span class="o">*</span> <span class="n">m</span> <span class="n">earth_perp_surface</span> <span class="o">=</span> <span 
class="n">pi</span> <span class="o">*</span> <span class="n">R_earth</span> <span class="o">**</span><span class="mi">2</span> <span class="n">sphere</span> <span class="o">=</span> <span class="mi">4</span> <span class="o">*</span> <span class="n">pi</span> <span class="o">*</span> <span class="n">D_earth_sun</span> <span class="o">**</span><span class="mi">2</span> <span class="n">radiation_received</span> <span class="o">=</span> <span class="n">radiation</span> <span class="o">/</span> <span class="n">sphere</span> <span class="nb">print</span><span class="p">(</span><span class="n">radiation_received</span><span class="p">)</span> </pre></div> </div> </div> </div> <div class="output_wrapper"> <div class="output"> <div class="output_area"> <div class="output_subarea output_stream output_stdout output_text"> <pre>1382.49374484614*watt/meter**2 </pre> </div> </div> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h2 id="Energy-received-by-the-earth-surface-(before-atmosphere)">Energy received by the earth surface (before atmosphere)<a class="anchor-link" href="#Energy-received-by-the-earth-surface-(before-atmosphere)"> </a></h2><p>$ \text{Energy received} = \text{radiation received} \times \frac{ \text{visible surface}}{ \text{earth's surface}} $</p> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="n">power_received</span> <span class="o">=</span> <span class="n">radiation_received</span> <span class="o">*</span> <span class="n">pi</span> <span class="o">*</span> <span class="n">R_earth</span> <span class="o">**</span><span class="mi">2</span> <span class="n">surface_power_received</span> <span class="o">=</span> <span class="n">power_received</span> <span class="o">/</span> <span 
class="p">(</span><span class="mi">4</span> <span class="o">*</span> <span class="n">pi</span> <span class="o">*</span> <span class="n">R_earth</span> <span class="o">**</span><span class="mi">2</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">surface_power_received</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">power_received</span><span class="o">.</span><span class="n">n</span><span class="p">())</span> </pre></div> </div> </div> </div> <div class="output_wrapper"> <div class="output"> <div class="output_area"> <div class="output_subarea output_stream output_stdout output_text"> <pre>345.623436211536*watt/meter**2 1.76290235470883e+17*watt </pre> </div> </div> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <blockquote><p>RADIATION RECEIVED BY SYSTEM EARTH = $345 W.m^{-2}$</p> <p>MAXIMUM POWER WITH EARTH "DYSON SPHERE": $176 PW$</p> </blockquote> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h1 id="Human-consumption">Human consumption<a class="anchor-link" href="#Human-consumption"> </a></h1><p>13 511 MTep <a href="https://www.iea.org/data-and-statistics?country=WORLD&amp;fuel=Energy%20supply&amp;indicator=Total%20primary%20energy%20supply%20%28TPES%29%20by%20source">Source International Energy agency</a></p> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="kn">from</span> <span class="nn">sympy.physics.units</span> <span class="kn">import</span> <span class="n">J</span><span class="p">,</span> <span class="n">s</span><span class="p">,</span> <span class="n">W</span> 
<span class="kn">from</span> <span class="nn">sympy.physics.units.util</span> <span class="kn">import</span> <span class="n">convert_to</span> <span class="n">million</span> <span class="o">=</span> <span class="mi">10</span> <span class="o">**</span><span class="mi">6</span> <span class="n">kilo</span> <span class="o">=</span> <span class="mi">10</span><span class="o">**</span><span class="mi">3</span> <span class="n">giga</span> <span class="o">=</span> <span class="mi">10</span> <span class="o">**</span> <span class="mi">9</span> <span class="n">toe</span> <span class="o">=</span> <span class="mf">41.868</span> <span class="o">*</span> <span class="n">giga</span> <span class="o">*</span> <span class="n">J</span> <span class="n">ktoe</span> <span class="o">=</span> <span class="n">kilo</span> <span class="o">*</span> <span class="n">toe</span> <span class="n">Mtoe</span> <span class="o">=</span> <span class="n">million</span> <span class="o">*</span> <span class="n">toe</span> <span class="n">hour</span> <span class="o">=</span> <span class="mi">60</span> <span class="o">*</span> <span class="mi">60</span> <span class="o">*</span> <span class="n">s</span> <span class="n">year</span> <span class="o">=</span> <span class="mi">24</span> <span class="o">*</span> <span class="n">hour</span> <span class="o">*</span> <span class="mf">365.25</span> <span class="n">base</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">([</span><span class="mi">3852538</span><span class="p">,</span><span class="mi">2949909</span><span class="p">,</span><span class="mi">670298</span><span class="p">,</span><span class="mi">335519</span><span class="p">,</span><span class="mi">204190</span><span class="p">,</span><span class="mi">1286064</span><span class="p">,</span><span class="mi">4329220</span><span class="p">])</span> <span class="n">Humanity_total_annual_consumption</span> <span class="o">=</span> <span class="n">base</span> <span class="o">*</span> <span
class="n">ktoe</span> <span class="n">humanity_power_consumption</span> <span class="o">=</span> <span class="n">Humanity_total_annual_consumption</span> <span class="o">/</span> <span class="n">year</span> <span class="nb">print</span><span class="p">(</span><span class="n">convert_to</span><span class="p">(</span><span class="n">humanity_power_consumption</span><span class="o">.</span><span class="n">n</span><span class="p">(),</span> <span class="p">[</span><span class="n">W</span><span class="p">])</span><span class="o">.</span><span class="n">n</span><span class="p">())</span> </pre></div> </div> </div> </div> <div class="output_wrapper"> <div class="output"> <div class="output_area"> <div class="output_subarea output_stream output_stdout output_text"> <pre>18080149776408.9*watt </pre> </div> </div> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="nb">print</span><span class="p">(</span><span class="n">convert_to</span><span class="p">(</span><span class="n">humanity_power_consumption</span> <span class="o">/</span> <span class="n">power_received</span> <span class="o">*</span> <span class="mi">100</span><span class="p">,</span> <span class="p">[</span><span class="n">J</span><span class="p">,</span> <span class="n">s</span><span class="p">])</span><span class="o">.</span><span class="n">n</span><span class="p">())</span> </pre></div> </div> </div> </div> <div class="output_wrapper"> <div class="output"> <div class="output_area"> <div class="output_subarea output_stream output_stdout output_text"> <pre>0.0102558997258785 </pre> </div> </div> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>We are currently consuming <strong>0.01% of the maximum capacity of the earth covered by a 
Dyson sphere of solar panels</strong>.</p> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h3 id="A-bit-more-realistic-approach">A bit more realistic approach<a class="anchor-link" href="#A-bit-more-realistic-approach"> </a></h3> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>After the atmosphere, only $168 W.m^{-2}$ hits the surface. It's quite complicated to compute, as it depends on the wavelength of the incoming light, clouds, the composition of the atmosphere and so on, so we just take the value from <a href="https://fr.wikipedia.org/wiki/Bilan_radiatif_de_la_Terre">here</a>.</p> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>Then, only 29% of the earth's surface is landmass (where we can reasonably put solar panels in large quantities).</p> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>Of that, 31% is covered in forest, which already acts as a natural solar panel we don't want to remove (for other obvious reasons) <a href="http://www.earth-policy.org/indicators/C56/forests_2012">source</a>, and 38.4% is agricultural land <a href="https://en.wikipedia.org/wiki/Agricultural_land">source</a>.</p> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>Then solar panels are not 100% efficient.
They are roughly only 20% efficient with current technology at a reasonable cost.</p> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="n">earth_power_received</span> <span class="o">=</span> <span class="mi">168</span> <span class="o">*</span> <span class="n">W</span> <span class="o">*</span> <span class="n">m</span> <span class="o">**</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">)</span> <span class="n">available_surface</span> <span class="o">=</span> <span class="mi">4</span> <span class="o">*</span> <span class="n">pi</span> <span class="o">*</span> <span class="n">R_earth</span> <span class="o">**</span><span class="mi">2</span> <span class="o">*</span> <span class="mf">0.29</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-.</span><span class="mi">31</span> <span class="o">-</span> <span class="o">.</span><span class="mi">384</span><span class="p">)</span> <span class="n">max_power</span> <span class="o">=</span> <span class="n">earth_power_received</span> <span class="o">*</span> <span class="n">available_surface</span> <span class="o">*</span> <span class="mf">0.2</span> <span class="nb">print</span><span class="p">(</span><span class="n">max_power</span><span class="o">.</span><span class="n">n</span><span class="p">())</span> <span class="nb">print</span><span class="p">(</span><span class="n">convert_to</span><span class="p">(</span><span class="n">humanity_power_consumption</span> <span class="o">/</span> <span class="n">max_power</span> <span class="o">*</span><span class="mi">100</span><span class="p">,</span> <span class="p">[</span><span class="n">J</span><span class="p">,</span> <span class="n">s</span><span class="p">])</span><span class="o">.</span><span class="n">n</span><span 
class="p">())</span> </pre></div> </div> </div> </div> <div class="output_wrapper"> <div class="output"> <div class="output_area"> <div class="output_subarea output_stream output_stdout output_text"> <pre>1.52084087357243e+15*watt
1.18882587196246 </pre> </div> </div> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h1 id="Conclusion">Conclusion<a class="anchor-link" href="#Conclusion"> </a></h1> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>In the end, we are currently consuming <strong>1.2% of the realistically available solar power</strong>. That would require placing solar panels everywhere on the planet that is not forest or agricultural land. And we don't yet account for energy return on energy invested (EROEI), which is likely to increase that percentage.</p> <p>NB: This is a very superficial attempt to evaluate these numbers; however, the result should be correct within an order of magnitude.</p> </div> </div> </div> </div>Can we train neural networks without gradient descent ?2020-03-10T00:00:00+01:002020-03-10T00:00:00+01:00http://localhost:4000/narsil.github.io/ml/2020/03/10/no-gd-training<!-- ################################################# ### THIS FILE WAS AUTOGENERATED! DO NOT EDIT!
### ################################################# # file to edit: _notebooks/2020-03-10-no-gd-training.ipynb --> <div class="container" id="notebook-container"> <div class="cell border-box-sizing code_cell rendered"> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h2 id="What's-the-problem-?">What's the problem ?<a class="anchor-link" href="#What's-the-problem-?"> </a></h2><p>ML models are usually not capable of estimating how close the data you feed them is to what was in the training dataset. This really matters in production, as models might make really stupid mistakes simply because an input lies outside the training set.</p> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>Let's train a simple mnist model (straight from the pytorch tutorial <a href="https://github.com/pytorch/examples/tree/master/mnist">https://github.com/pytorch/examples/tree/master/mnist</a>)</p> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <details class="description"> <summary class="btn btn-sm" data-open="Hide Code" data-close="Show Code"></summary> <p><div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="c1">#collapse</span> <span class="kn">from</span> <span class="nn">__future__</span> <span class="kn">import</span> <span class="n">print_function</span> <span class="kn">import</span> <span class="nn">argparse</span> <span class="kn">import</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span> <span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span> <span class="kn">import</span>
<span class="nn">torch.optim</span> <span class="k">as</span> <span class="nn">optim</span> <span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">datasets</span><span class="p">,</span> <span class="n">transforms</span> <span class="kn">from</span> <span class="nn">torch.optim.lr_scheduler</span> <span class="kn">import</span> <span class="n">StepLR</span> <span class="kn">import</span> <span class="nn">os</span> <span class="k">class</span> <span class="nc">Net</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="nb">super</span><span class="p">(</span><span class="n">Net</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout1</span> <span class="o">=</span> <span class="n">nn</span><span 
class="o">.</span><span class="n">Dropout2d</span><span class="p">(</span><span class="mf">0.25</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout2d</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">9216</span><span class="p">,</span> <span class="mi">128</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">max_pool2d</span><span class="p">(</span><span class="n">x</span><span 
class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">output</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">log_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="k">return</span> <span class="n">output</span> <span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">train_loader</span><span 
class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">):</span> <span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">()</span> <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">train_loader</span><span class="p">):</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">target</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span> <span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">nll_loss</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span> <span class="k">if</span> <span class="n">batch_idx</span> <span class="o">%</span> <span class="n">args</span><span class="o">.</span><span class="n">log_interval</span> <span class="o">==</span> <span class="mi">0</span><span 
class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Train Epoch: </span><span class="si">{}</span><span class="s1"> [</span><span class="si">{}</span><span class="s1">/</span><span class="si">{}</span><span class="s1"> (</span><span class="si">{:.0f}</span><span class="s1">%)]</span><span class="se">\t</span><span class="s1">Loss: </span><span class="si">{:.6f}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">batch_idx</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">),</span> <span class="mf">100.</span> <span class="o">*</span> <span class="n">batch_idx</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_loader</span><span class="p">),</span> <span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">()))</span> <span class="k">def</span> <span class="nf">test</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">):</span> <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span> <span class="n">test_loss</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">correct</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span> <span class="k">for</span> <span class="n">data</span><span 
class="p">,</span> <span class="n">target</span> <span class="ow">in</span> <span class="n">test_loader</span><span class="p">:</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">target</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> <span class="n">test_loss</span> <span class="o">+=</span> <span class="n">F</span><span class="o">.</span><span class="n">nll_loss</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;sum&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="c1"># sum up batch loss</span> <span class="n">pred</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># get the index of the max log-probability</span> <span class="n">correct</span> <span class="o">+=</span> <span class="n">pred</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">target</span><span class="o">.</span><span class="n">view_as</span><span class="p">(</span><span class="n">pred</span><span class="p">))</span><span class="o">.</span><span 
class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="n">test_loss</span> <span class="o">/=</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">Test set: Average loss: </span><span class="si">{:.4f}</span><span class="s1">, Accuracy: </span><span class="si">{}</span><span class="s1">/</span><span class="si">{}</span><span class="s1"> (</span><span class="si">{:.0f}</span><span class="s1">%)</span><span class="se">\n</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span> <span class="n">test_loss</span><span class="p">,</span> <span class="n">correct</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">),</span> <span class="mf">100.</span> <span class="o">*</span> <span class="n">correct</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">)))</span> <span class="k">def</span> <span class="nf">mnist</span><span class="p">():</span> <span class="n">filename</span> <span class="o">=</span><span class="s2">&quot;mnist_cnn.pt&quot;</span> <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">filename</span><span class="p">):</span> <span class="k">return</span> <span class="c1"># Training settings</span> <span class="n">parser</span> <span class="o">=</span> <span 
class="n">argparse</span><span class="o">.</span><span class="n">ArgumentParser</span><span class="p">(</span><span class="n">description</span><span class="o">=</span><span class="s1">&#39;PyTorch MNIST Example&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--batch-size&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;input batch size for training (default: 64)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--test-batch-size&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">1000</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;input batch size for testing (default: 1000)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--epochs&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">metavar</span><span 
class="o">=</span><span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;number of epochs to train (default: 4)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--lr&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">float</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;LR&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;learning rate (default: 1.0)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--gamma&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">float</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mf">0.7</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;M&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;Learning rate step gamma (default: 0.7)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--no-cuda&#39;</span><span class="p">,</span> <span class="n">action</span><span class="o">=</span><span class="s1">&#39;store_true&#39;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span
class="s1">&#39;disables CUDA training&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--seed&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;S&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;random seed (default: 1)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--log-interval&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;how many batches to wait before logging training status&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--save-model&#39;</span><span class="p">,</span> <span class="n">action</span><span class="o">=</span><span class="s1">&#39;store_true&#39;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;For Saving the current Model&#39;</span><span class="p">)</span> <span class="n">args</span> <span class="o">=</span> <span 
class="n">parser</span><span class="o">.</span><span class="n">parse_args</span><span class="p">()</span> <span class="n">use_cuda</span> <span class="o">=</span> <span class="ow">not</span> <span class="n">args</span><span class="o">.</span><span class="n">no_cuda</span> <span class="ow">and</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">args</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cuda&quot;</span> <span class="k">if</span> <span class="n">use_cuda</span> <span class="k">else</span> <span class="s2">&quot;cpu&quot;</span><span class="p">)</span> <span class="n">kwargs</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;num_workers&#39;</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="s1">&#39;pin_memory&#39;</span><span class="p">:</span> <span class="kc">True</span><span class="p">}</span> <span class="k">if</span> <span class="n">use_cuda</span> <span class="k">else</span> <span class="p">{}</span> <span class="n">train_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span> <span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">(</span><span class="s1">&#39;../data&#39;</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span 
class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,))</span> <span class="p">])),</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="n">test_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span> <span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">expanduser</span><span class="p">(</span><span class="s1">&#39;../data&#39;</span><span class="p">),</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span> <span class="n">transforms</span><span class="o">.</span><span 
class="n">ToTensor</span><span class="p">(),</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,))</span> <span class="p">])),</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">test_batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="o">.</span><span class="n">Adadelta</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">lr</span><span class="p">)</span> <span class="n">scheduler</span> <span class="o">=</span> <span class="n">StepLR</span><span class="p">(</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">step_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">gamma</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">gamma</span><span class="p">)</span> <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">args</span><span 
class="o">.</span><span class="n">epochs</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span> <span class="n">train</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">)</span> <span class="n">test</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">)</span> <span class="n">scheduler</span><span class="o">.</span><span class="n">step</span><span class="p">()</span> <span class="k">if</span> <span class="n">args</span><span class="o">.</span><span class="n">save_model</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">filename</span><span class="p">)</span> <span class="c1"># mnist()</span> </pre></div> </div> </div> </div> </p> </details> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>Other out-of-distribution detectors have been proposed.
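A common baseline worth keeping in mind is the maximum softmax probability of the network's output: a low score hints that an input sits off the training distribution. A minimal, framework-free sketch (the helper name is mine, not from the notebook):

```python
import math

def max_softmax_score(logits):
    # Softmax over the logits, then take the winning class probability.
    # A low score suggests the input may be far from the training set.
    exps = [math.exp(l) for l in logits]
    return max(exps) / sum(exps)

print(max_softmax_score([10.0, 0.0, 0.0]))  # close to 1.0: confident
print(max_softmax_score([0.1, 0.0, 0.1]))   # close to 1/3: unsure
```

Since the <code>Net</code> above returns log-probabilities (it ends with <code>log_softmax</code>), the equivalent score there is simply <code>output.exp().max()</code>.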
Here is a sample of methods:</p> <ul> <li>Genetic algorithms</li> <li>DFO</li> <li>Simulated annealing</li> </ul> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h2 id="Experiments">Experiments<a class="anchor-link" href="#Experiments"> </a></h2> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="k">def</span> <span class="nf">train_ticket</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">):</span> <span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">()</span> <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">train_loader</span><span class="p">):</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">target</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span> <span 
class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">nll_loss</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span> <span class="k">if</span> <span class="n">batch_idx</span> <span class="o">%</span> <span class="n">args</span><span class="o">.</span><span class="n">log_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Train Epoch: </span><span class="si">{}</span><span class="s1"> [</span><span class="si">{}</span><span class="s1">/</span><span class="si">{}</span><span class="s1"> (</span><span class="si">{:.0f}</span><span class="s1">%)]</span><span class="se">\t</span><span class="s1">Loss: </span><span class="si">{:.6f}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">batch_idx</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">),</span> <span class="mf">100.</span> <span class="o">*</span> <span class="n">batch_idx</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_loader</span><span class="p">),</span> <span 
class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">()))</span> <span class="k">def</span> <span class="nf">test_ticket</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">):</span> <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span> <span class="n">test_loss</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">correct</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span> <span class="k">for</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="ow">in</span> <span class="n">test_loader</span><span class="p">:</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">target</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> <span class="n">test_loss</span> <span class="o">+=</span> <span class="n">F</span><span class="o">.</span><span class="n">nll_loss</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;sum&#39;</span><span class="p">)</span><span 
class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="c1"># sum up batch loss</span> <span class="n">pred</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># get the index of the max log-probability</span> <span class="n">correct</span> <span class="o">+=</span> <span class="n">pred</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">target</span><span class="o">.</span><span class="n">view_as</span><span class="p">(</span><span class="n">pred</span><span class="p">))</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="n">test_loss</span> <span class="o">/=</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">Test set: Average loss: </span><span class="si">{:.4f}</span><span class="s1">, Accuracy: </span><span class="si">{}</span><span class="s1">/</span><span class="si">{}</span><span class="s1"> (</span><span class="si">{:.0f}</span><span class="s1">%)</span><span class="se">\n</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span> <span class="n">test_loss</span><span class="p">,</span> <span class="n">correct</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span 
class="n">dataset</span><span class="p">),</span> <span class="mf">100.</span> <span class="o">*</span> <span class="n">correct</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">)))</span> <span class="k">def</span> <span class="nf">ticket_finder</span><span class="p">():</span> <span class="n">filename</span> <span class="o">=</span><span class="s2">&quot;ticket_finder.pt&quot;</span> <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">filename</span><span class="p">):</span> <span class="k">return</span> <span class="c1"># Training settings</span> <span class="n">parser</span> <span class="o">=</span> <span class="n">argparse</span><span class="o">.</span><span class="n">ArgumentParser</span><span class="p">(</span><span class="n">description</span><span class="o">=</span><span class="s1">&#39;PyTorch MNIST Example&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--batch-size&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;input batch size for training (default: 64)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--test-batch-size&#39;</span><span class="p">,</span> <span 
class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">1000</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;input batch size for testing (default: 1000)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--epochs&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;number of epochs to train (default: 4)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--lr&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">float</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;LR&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;learning rate (default: 1.0)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--gamma&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span
class="nb">float</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mf">0.7</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;M&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;Learning rate step gamma (default: 0.7)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--no-cuda&#39;</span><span class="p">,</span> <span class="n">action</span><span class="o">=</span><span class="s1">&#39;store_true&#39;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;disables CUDA training&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--seed&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;S&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;random seed (default: 1)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--log-interval&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span 
class="n">metavar</span><span class="o">=</span><span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;how many batches to wait before logging training status&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--save-model&#39;</span><span class="p">,</span> <span class="n">action</span><span class="o">=</span><span class="s1">&#39;store_true&#39;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;For Saving the current Model&#39;</span><span class="p">)</span> <span class="n">args</span> <span class="o">=</span> <span class="n">parser</span><span class="o">.</span><span class="n">parse_args</span><span class="p">()</span> <span class="n">use_cuda</span> <span class="o">=</span> <span class="ow">not</span> <span class="n">args</span><span class="o">.</span><span class="n">no_cuda</span> <span class="ow">and</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">args</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cuda&quot;</span> <span class="k">if</span> <span class="n">use_cuda</span> <span class="k">else</span> <span class="s2">&quot;cpu&quot;</span><span class="p">)</span> <span class="n">kwargs</span> <span class="o">=</span> <span class="p">{</span><span 
class="s1">&#39;num_workers&#39;</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="s1">&#39;pin_memory&#39;</span><span class="p">:</span> <span class="kc">True</span><span class="p">}</span> <span class="k">if</span> <span class="n">use_cuda</span> <span class="k">else</span> <span class="p">{}</span> <span class="n">train_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span> <span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">(</span><span class="s1">&#39;../data&#39;</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,))</span> <span class="p">])),</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="n">test_loader</span> <span class="o">=</span> <span 
class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span> <span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">expanduser</span><span class="p">(</span><span class="s1">&#39;../data&#39;</span><span class="p">),</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,))</span> <span class="p">])),</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">test_batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">TicketFinder</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span 
class="n">parameters</span><span class="p">())</span> <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">args</span><span class="o">.</span><span class="n">epochs</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span> <span class="n">train_ticket</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">)</span> <span class="n">test_ticket</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">)</span> <span class="k">if</span> <span class="n">args</span><span class="o">.</span><span class="n">save_model</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">filename</span><span class="p">)</span> </pre></div> </div> </div> </div> </div> </div>Running a docker with GPU enabled (for pytorch and tensorflow)2020-03-04T00:00:00+01:002020-03-04T00:00:00+01:00http://localhost:4000/narsil.github.io/ml/docker/2020/03/04/running-gpu-enabled-docker<p>Sometimes, to isolate dependencies, you might want to use docker to containerize your projects.
You can also use it for GPU workloads. In order to run docker images with GPU enabled, you are going to need:</p> <h1 id="install-docker">Install docker</h1> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>sudo apt-get install \
    apt-transport-https \
    ca-certificates \
    curl \
    gnupg-agent \
    software-properties-common
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add -
sudo add-apt-repository \
   "deb [arch=amd64] https://download.docker.com/linux/ubuntu \
   $(lsb_release -cs) \
   stable"
sudo apt-get update
sudo apt-get install docker-ce docker-ce-cli containerd.io
</code></pre></div></div> <p><a href="https://docs.docker.com/install/linux/docker-ce/ubuntu/">source</a></p> <h1 id="install-nvidia-container-toolkit">Install nvidia-container-toolkit</h1> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Add the package repositories
distribution=$(. /etc/os-release;echo $ID$VERSION_ID)
curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add -
curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list
sudo apt-get update &amp;&amp; sudo apt-get install -y nvidia-container-toolkit
sudo systemctl restart docker
</code></pre></div></div> <p><a href="https://github.com/NVIDIA/nvidia-docker">source</a></p> <h1 id="launch-the-docker-for-pytorch">Launch the docker for PyTorch</h1> <p>In order to use CUDA you need an NVIDIA-enabled image; that will make everything simpler.
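Once such an image is running, a short probe confirms that PyTorch actually sees the GPU. This is an illustrative sketch of mine, not part of the NVIDIA images; it assumes torch is installed in the environment (the nvcr.io images ship with it) and the `cuda_status` helper name is hypothetical:

```python
# Report what PyTorch can see; degrades gracefully when torch or a GPU is absent.
def cuda_status():
    try:
        import torch
    except ImportError:
        return "torch is not installed in this environment"
    if not torch.cuda.is_available():
        return "CUDA unavailable: check the --gpus flag and nvidia-container-toolkit"
    # At least one visible device: name the first one.
    return "CUDA available: " + torch.cuda.get_device_name(0)

print(cuda_status())
```

Inside a container launched with `--gpus all` this should name the device; on a CPU-only host it prints the fallback message instead of crashing.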
You could of course link your own CUDA library via volume mounting, but it’s cumbersome (and I didn’t check that it works).</p> <ol> <li>Create an account on <a href="https://ngc.nvidia.com/">https://ngc.nvidia.com/</a></li> <li>Go to the create an API key page <a href="https://ngc.nvidia.com/setup/api-key">https://ngc.nvidia.com/setup/api-key</a></li> <li>Generate the key and copy it</li> </ol> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>docker login nvcr.io
Username: $oauthtoken
Password: &lt;Your Key&gt;

docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:20.02-py3 bash
python -c "import torch; print(torch.cuda.is_available())"
# True
</code></pre></div></div> <p>If you fail to log in, the <code class="language-plaintext highlighter-rouge">docker run</code> command will fail with an <code class="language-plaintext highlighter-rouge">unauthenticated</code> error.</p> <p>Caveat: This is the only option for now; docker-compose <em>CANNOT</em> pass the --gpus option. To check updates for docker compose, look at this <a href="https://github.com/docker/compose/issues/6691">issue</a>.</p> <p>Bonus: Nvidia put up <em>a lot</em> of containers with various libraries enabled; check them out in their <a href="https://ngc.nvidia.com/catalog/">catalog</a>.</p> <h2 id="enjoy-">Enjoy!</h2>nicolasSelf KL-divergence for detecting out of distribution data and unsupervised text classification2020-02-26T00:00:00+01:002020-02-26T00:00:00+01:00http://localhost:4000/narsil.github.io/ml/nlp/kldivergence/2020/02/26/self-kl-models<!-- ################################################# ### THIS FILE WAS AUTOGENERATED! DO NOT EDIT!
### ################################################# # file to edit: _notebooks/2020-02-26-self-kl-models.ipynb --> <div class="container" id="notebook-container"> <div class="cell border-box-sizing code_cell rendered"> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <blockquote><p>TL;DR. By training two models in the same dataset order, with the same architecture and the same loss but different initializations, I was able to obtain a consistent out-of-distribution detector by measuring the KL-divergence between model outputs. This out-of-distribution measure, used on text, could lead to unsupervised text classification.</p> </blockquote> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h2 id="What's-the-problem-?">What's the problem?<a class="anchor-link" href="#What's-the-problem-?"> </a></h2><p>ML models are usually not capable of predicting how close the data you<br /> feed them is to what was in the training data.
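To make the TL;DR concrete, the measurement can be sketched on plain probability vectors. This is an illustrative toy, not code from the notebook; the `kl` and `symmetric_kl` helper names are my own:

```python
import math

def kl(p, q):
    # KL(p || q) = sum_i p_i * log(p_i / q_i); assumes q_i > 0 wherever p_i > 0.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def symmetric_kl(p, q):
    # KL is asymmetric, so average both directions to get a symmetric score.
    return 0.5 * (kl(p, q) + kl(q, p))

# Softmax outputs of two independently initialized models on the same input:
agree_a = [0.90, 0.05, 0.05]   # model 1, in-distribution input
agree_b = [0.88, 0.07, 0.05]   # model 2, same input: close to model 1
disagree = [0.10, 0.80, 0.10]  # model 2 on an out-of-distribution input

print(symmetric_kl(agree_a, agree_b))   # small: the two models agree
print(symmetric_kl(agree_a, disagree))  # much larger: strong disagreement
```

When the two output distributions nearly agree the score stays near zero, and it grows as they diverge, which is exactly the signal used here to flag out-of-distribution inputs.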
It really matters for production models, as they might make really stupid mistakes just because the inputs fall outside<br /> the training set.</p> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>Let's train a simple MNIST model (straight out of the PyTorch tutorial <a href="https://github.com/pytorch/examples/tree/master/mnist">https://github.com/pytorch/examples/tree/master/mnist</a>).</p> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <details class="description"> <summary class="btn btn-sm" data-open="Hide Code" data-close="Show Code"></summary> <p><div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="c1">#collapse</span> <span class="kn">from</span> <span class="nn">__future__</span> <span class="kn">import</span> <span class="n">print_function</span> <span class="kn">import</span> <span class="nn">argparse</span> <span class="kn">import</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span> <span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span> <span class="kn">import</span> <span class="nn">torch.optim</span> <span class="k">as</span> <span class="nn">optim</span> <span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">datasets</span><span class="p">,</span> <span class="n">transforms</span> <span class="kn">from</span> <span class="nn">torch.optim.lr_scheduler</span> <span class="kn">import</span> <span class="n">StepLR</span> <span class="kn">import</span> <span class="nn">os</span> <span class="k">class</span> <span class="nc">Net</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span
class="n">Module</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="nb">super</span><span class="p">(</span><span class="n">Net</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout2d</span><span class="p">(</span><span class="mf">0.25</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout2d</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span 
class="mi">9216</span><span class="p">,</span> <span class="mi">128</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">max_pool2d</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span 
class="o">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">output</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">log_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="k">return</span> <span class="n">output</span> <span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">):</span> <span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">()</span> <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">train_loader</span><span class="p">):</span> <span class="n">data</span><span 
class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">target</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span> <span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">nll_loss</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span> <span class="k">if</span> <span class="n">batch_idx</span> <span class="o">%</span> <span class="n">args</span><span class="o">.</span><span class="n">log_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Train Epoch: </span><span class="si">{}</span><span class="s1"> [</span><span class="si">{}</span><span class="s1">/</span><span class="si">{}</span><span class="s1"> (</span><span class="si">{:.0f}</span><span class="s1">%)]</span><span class="se">\t</span><span class="s1">Loss: </span><span class="si">{:.6f}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">batch_idx</span> <span class="o">*</span> 
<span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">),</span> <span class="mf">100.</span> <span class="o">*</span> <span class="n">batch_idx</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_loader</span><span class="p">),</span> <span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">()))</span> <span class="k">def</span> <span class="nf">test</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">):</span> <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span> <span class="n">test_loss</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">correct</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span> <span class="k">for</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="ow">in</span> <span class="n">test_loader</span><span class="p">:</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">target</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="n">output</span> <span class="o">=</span> <span 
class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> <span class="n">test_loss</span> <span class="o">+=</span> <span class="n">F</span><span class="o">.</span><span class="n">nll_loss</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;sum&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="c1"># sum up batch loss</span> <span class="n">pred</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># get the index of the max log-probability</span> <span class="n">correct</span> <span class="o">+=</span> <span class="n">pred</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">target</span><span class="o">.</span><span class="n">view_as</span><span class="p">(</span><span class="n">pred</span><span class="p">))</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="n">test_loss</span> <span class="o">/=</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">Test set: Average loss: </span><span class="si">{:.4f}</span><span class="s1">, Accuracy: </span><span class="si">{}</span><span 
class="s1">/</span><span class="si">{}</span><span class="s1"> (</span><span class="si">{:.0f}</span><span class="s1">%)</span><span class="se">\n</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span> <span class="n">test_loss</span><span class="p">,</span> <span class="n">correct</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">),</span> <span class="mf">100.</span> <span class="o">*</span> <span class="n">correct</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">)))</span> <span class="k">def</span> <span class="nf">mnist</span><span class="p">():</span> <span class="n">filename</span> <span class="o">=</span> <span class="s2">&quot;mnist_cnn.pt&quot;</span> <span class="c1"># Training settings</span> <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">filename</span><span class="p">):</span> <span class="k">return</span> <span class="n">parser</span> <span class="o">=</span> <span class="n">argparse</span><span class="o">.</span><span class="n">ArgumentParser</span><span class="p">(</span><span class="n">description</span><span class="o">=</span><span class="s1">&#39;PyTorch MNIST Example&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--batch-size&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">64</span><span 
class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;input batch size for training (default: 64)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--test-batch-size&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">1000</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;input batch size for testing (default: 1000)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--epochs&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">14</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;number of epochs to train (default: 14)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--lr&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">float</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span 
class="n">metavar</span><span class="o">=</span><span class="s1">&#39;LR&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;learning rate (default: 1.0)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--gamma&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">float</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mf">0.7</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;M&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;Learning rate step gamma (default: 0.7)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--no-cuda&#39;</span><span class="p">,</span> <span class="n">action</span><span class="o">=</span><span class="s1">&#39;store_true&#39;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;disables CUDA training&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--seed&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;S&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span 
class="s1">&#39;random seed (default: 1)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--log-interval&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;how many batches to wait before logging training status&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--save-model&#39;</span><span class="p">,</span> <span class="n">action</span><span class="o">=</span><span class="s1">&#39;store_true&#39;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;For Saving the current Model&#39;</span><span class="p">)</span> <span class="n">args</span> <span class="o">=</span> <span class="n">parser</span><span class="o">.</span><span class="n">parse_args</span><span class="p">()</span> <span class="n">use_cuda</span> <span class="o">=</span> <span class="ow">not</span> <span class="n">args</span><span class="o">.</span><span class="n">no_cuda</span> <span class="ow">and</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">args</span><span class="o">.</span><span 
class="n">seed</span><span class="p">)</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cuda&quot;</span> <span class="k">if</span> <span class="n">use_cuda</span> <span class="k">else</span> <span class="s2">&quot;cpu&quot;</span><span class="p">)</span> <span class="n">kwargs</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;num_workers&#39;</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="s1">&#39;pin_memory&#39;</span><span class="p">:</span> <span class="kc">True</span><span class="p">}</span> <span class="k">if</span> <span class="n">use_cuda</span> <span class="k">else</span> <span class="p">{}</span> <span class="n">train_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span> <span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">(</span><span class="s1">&#39;../data&#39;</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span 
class="mf">0.3081</span><span class="p">,))</span> <span class="p">])),</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="n">test_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span> <span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">expanduser</span><span class="p">(</span><span class="s1">&#39;../data&#39;</span><span class="p">),</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,))</span> <span class="p">])),</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">test_batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span 
class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="o">.</span><span class="n">Adadelta</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">lr</span><span class="p">)</span> <span class="n">scheduler</span> <span class="o">=</span> <span class="n">StepLR</span><span class="p">(</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">step_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">gamma</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">gamma</span><span class="p">)</span> <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">args</span><span class="o">.</span><span class="n">epochs</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span> <span class="n">train</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">)</span> <span class="n">test</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span 
class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">)</span> <span class="n">scheduler</span><span class="o">.</span><span class="n">step</span><span class="p">()</span> <span class="k">if</span> <span class="n">args</span><span class="o">.</span><span class="n">save_model</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">filename</span><span class="p">)</span> <span class="c1"># mnist()</span> </pre></div> </div> </div> </div> </p> </details> </div> <div class="cell border-box-sizing code_cell rendered"> <details class="description"> <summary class="btn btn-sm" data-open="Hide Code" data-close="Show Code"></summary> <p><div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="c1">#collapse</span> <span class="kn">from</span> <span class="nn">torch.distributions</span> <span class="kn">import</span> <span class="n">Categorical</span> <span class="kn">from</span> <span class="nn">torch.nn.parameter</span> <span class="kn">import</span> <span class="n">Parameter</span> <span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">transforms</span> <span class="k">def</span> <span class="nf">attack_simple</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span> <span class="n">dummy_input</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span 
class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span> <span class="n">lr</span> <span class="o">=</span> <span class="mi">1</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="o">.</span><span class="n">Adadelta</span><span class="p">([</span><span class="n">dummy_input</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span> <span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">dummy_input</span><span class="p">)</span> <span class="n">entropy</span> <span class="o">=</span> <span class="n">Categorical</span><span class="p">(</span><span class="n">logits</span> <span class="o">=</span> <span class="n">output</span><span class="p">)</span><span class="o">.</span><span class="n">entropy</span><span class="p">()</span> <span class="c1"># print(f&#39;Entropy {entropy.item():.2f}&#39;)</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span> <span class="n">entropy</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span> <span class="n">MAX</span> <span class="o">=</span> <span class="n">output</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">exp</span><span 
class="p">()</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="n">pil_img</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Resize</span><span class="p">((</span><span class="mi">240</span><span class="p">,</span> <span class="mi">240</span><span class="p">))(</span><span class="n">transforms</span><span class="o">.</span><span class="n">ToPILImage</span><span class="p">()(</span><span class="n">dummy_input</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span> <span class="k">return</span> <span class="p">(</span><span class="n">MAX</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mf">0.8</span><span class="p">,</span> <span class="n">MAX</span><span class="p">,</span> <span class="n">pil_img</span><span class="p">)</span> <span class="k">def</span> <span class="nf">check_attack</span><span class="p">():</span> <span class="n">mnist_model</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span> <span class="n">mnist_model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s1">&#39;mnist_cnn.pt&#39;</span><span class="p">))</span> <span class="n">success</span><span class="p">,</span> <span class="n">MAX</span><span class="p">,</span> <span class="n">pil_img</span> <span class="o">=</span> <span class="n">attack_simple</span><span class="p">(</span><span class="n">mnist_model</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;MNIST Model says : This is a 
</span><span class="si">{</span><span class="n">MAX</span><span class="o">.</span><span class="n">indices</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">}</span><span class="s2"> with probability </span><span class="si">{</span><span class="n">MAX</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">*</span> <span class="mi">100</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2">%&quot;</span><span class="p">)</span> <span class="n">display</span><span class="p">(</span><span class="n">pil_img</span><span class="p">)</span> <span class="n">success_rate</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">attack_simple</span><span class="p">(</span><span class="n">mnist_model</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">))</span> <span class="o">/</span> <span class="mf">100.</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Success rate </span><span class="si">{</span><span class="n">success_rate</span> <span class="o">*</span> <span class="mi">100</span><span class="si">:</span><span class="s2"> .2f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="c1"># check_attack()</span> </pre></div> </div> </div> </div> </p> </details> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>Then we generate a random image for which the model is highly confident, even though it's completely absurd.
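The collapsed code above implements this attack; stripped to its core, the idea looks like this (a minimal sketch — the function name is ours, and we assume a trained `model` that returns log-probabilities, as the MNIST `Net` above does):

```python
import torch
import torch.optim as optim
from torch.distributions import Categorical

def generate_confident_garbage(model, steps=100, lr=1.0):
    """Gradient-descend on the *input* pixels to minimize prediction entropy."""
    x = torch.rand(1, 1, 28, 28, requires_grad=True)
    optimizer = optim.Adadelta([x], lr=lr)
    for _ in range(steps):
        # lower entropy == more confident prediction, regardless of content
        entropy = Categorical(logits=model(x)).entropy()
        optimizer.zero_grad()
        entropy.backward()
        optimizer.step()
    with torch.no_grad():
        probs = model(x)[0].exp()  # the model returns log-probabilities
    return x.detach(), probs.argmax().item(), probs.max().item()
```

Only the input is optimized; the model's weights are untouched, yet its reported confidence climbs.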
This new image is out of distribution, yet the model does not know it. We want to avoid making such mistakes in production.</p> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h2 id="Other-approaches">Other approaches<a class="anchor-link" href="#Other-approaches"> </a></h2> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>Other out-of-distribution detectors have been proposed. Here is a sample of methods:</p> <ul> <li><a href="https://arxiv.org/pdf/1906.02845.pdf">Likelihood Ratios for Out-of-Distribution Detection</a>: Proposes learning two distinct models, one "raw" and one with perturbations instilled into its dataset, and looking at the log-likelihood ratio of the two models. The claim is that the difference between the two reflects how "far" an input is from the semantic part of the manifold of X: if $p(x) = p(x_{background})p(x_{semantic})$, the perturbation needs to lie only on $x_{semantic}$.</li> <li><a href="https://arxiv.org/pdf/1910.04241.pdf">Out-of-distribution Detection in Classifiers via Generation</a>: Proposes using an autoencoder (or GAN) to generate a low-dimensional representation of the manifold of the dataset X, then perturbing X in that representation. Those perturbed examples are trained to become a new "class" of the classifier's output.</li> <li><a href="https://arxiv.org/pdf/1706.02690.pdf">Enhancing the reliability of Out-of-Distribution Image Detection in Neural Networks (Odin)</a>: This one uses temperature scaling of the softmax and small input perturbations, then checks whether the softmax probability passes a threshold. IMO, this paper is interesting as it assumes smoothness properties on in-distribution data, and less smoothness out-of-distribution.
It does require some out-of-distribution examples for fitting its 3 hyperparameters (temperature, threshold, and magnitude of perturbation).</li> <li><p><a href="https://openreview.net/pdf?id=Hkxzx0NtDB">Your classifier is secretly an energy based model and you should treat it like one</a>: This one basically adds a new term to the loss to estimate p(x). Multiple OOD detectors are proposed, the most efficient being based on the second derivative of p(x), again claiming that the density p(x) changes more wildly in OOD space, which makes for a good OOD detector.</p> </li> <li><p><a href="https://arxiv.org/pdf/1810.01392.pdf">WAIC, but Why? Generative Ensembles for Robust Anomaly Detection</a>: This paper proposes to use an ensemble of models and look at the WAIC criterion to detect OOD inputs. It makes many comparisons to VAEs and GANs.</p> </li> <li><p><a href="https://arxiv.org/pdf/1802.04865v1.pdf">Learning Confidence for Out-of-Distribution Detection in Neural Networks</a>: The core idea in this paper is to change the learning loss so that confidence is learned as a task prior to the classification task: the model is allowed to see the real label only when it claims it can solve the problem, outputting a confidence score directly via another head. The caveat is that the model might choose to give up and always defer, so another trick is proposed to emphasize the in-distribution vs out-of-distribution gap by preprocessing inputs to move them towards regions of higher confidence. In-distribution inputs tend to move closer to 1 than out-of-distribution ones.
So the direct confidence estimator seems to be <em>smoother</em> out-of-distribution than in-distribution, where peaks are more likely to be found.</p> </li> <li><p><a href="https://paperswithcode.com/task/out-of-distribution-detection">Papers with code</a>: More links on the topic.</p> </li> </ul> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h2 id="Our-approach">Our approach<a class="anchor-link" href="#Our-approach"> </a></h2> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <blockquote><p><strong>Tl;dr: Make two similar models with two different random initializations, then train them at the same time. The OOD detector is simply a threshold classifier on the KL-divergence between the two outputs.</strong> The core argument for this approach is that a neural network captures the dataset manifold (meaning it produces "regular" outputs for in-dataset items), while elsewhere in input space its outputs remain essentially random for a random initialization. If that is true, then when we train the model, we shift its outputs only on the dataset manifold, and nowhere else. Under that assumption, two models that have been initialized differently have a very low probability of agreeing in their outputs outside of the manifold.</p> </blockquote> <p>It's quite close to WAIC, <em>but</em> the two models need to be trained at the same time. The argument is that this should align gradients during the training phase, leading to more correlated in-dataset predictions between the models. This supposes that the lottery ticket hypothesis is true, and adds that the lottery ticket is unique (or at least that the class of lottery tickets is very thin, and that they all highly correlate with each other).
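Leaving the lottery-ticket argument aside for a moment, the detection rule itself is easy to sketch (a minimal sketch — the helper names and calibration factor are illustrative, and we assume both models output log-probabilities, as the MNIST models here do):

```python
import torch

def kl_divergence(log_p, log_q):
    """Per-sample KL(p || q) between two log-probability tensors of shape (N, C)."""
    return (log_p.exp() * (log_p - log_q)).sum(dim=-1)

def calibrate_threshold(model_a, model_b, train_loader, factor=10.0):
    """Threshold = factor x the average train-set KL between the two models."""
    scores = []
    with torch.no_grad():
        for data, _ in train_loader:
            scores.append(kl_divergence(model_a(data), model_b(data)))
    return factor * torch.cat(scores).mean().item()

def is_ood(model_a, model_b, data, threshold):
    """Flag inputs on which the two differently-initialized models disagree."""
    with torch.no_grad():
        return kl_divergence(model_a(data), model_b(data)) > threshold
```

In production, anything flagged by `is_ood` would be rejected instead of classified.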
If this is true, then the gradients within the network that correspond to this lottery ticket winner in <em>both</em> networks should be the same (or highly correlated).</p> <p>In order to set the threshold, we found that simply using 10x the average KL-divergence obtained on the train dataset worked pretty well. As KL-divergence is measured in bits, 10x is quite a large margin. More work could be done to study the behaviour of this self KL-divergence more closely.</p> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h2 id="Experiments">Experiments<a class="anchor-link" href="#Experiments"> </a></h2> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h3 id="Experiment-1">Experiment 1<a class="anchor-link" href="#Experiment-1"> </a></h3><p>The MNIST attack-like failure presented before.</p> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="k">class</span> <span class="nc">MultiNet</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">models</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">models</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">(</span><span
class="n">models</span><span class="p">)</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="k">return</span> <span class="p">[</span><span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">model</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">models</span><span class="p">]</span> <span class="k">def</span> <span class="nf">train_multi</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">):</span> <span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">()</span> <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">train_loader</span><span class="p">):</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">target</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span> <span 
class="n">outputs</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> <span class="n">loss</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">F</span><span class="o">.</span><span class="n">nll_loss</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="k">for</span> <span class="n">output</span> <span class="ow">in</span> <span class="n">outputs</span><span class="p">)</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span> <span class="k">if</span> <span class="n">batch_idx</span> <span class="o">%</span> <span class="n">args</span><span class="o">.</span><span class="n">log_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Train Epoch: </span><span class="si">{}</span><span class="s1"> [</span><span class="si">{}</span><span class="s1">/</span><span class="si">{}</span><span class="s1"> (</span><span class="si">{:.0f}</span><span class="s1">%)]</span><span class="se">\t</span><span class="s1">Loss: </span><span class="si">{:.6f}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">batch_idx</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">),</span> <span class="mf">100.</span> <span 
class="o">*</span> <span class="n">batch_idx</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_loader</span><span class="p">),</span> <span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">()))</span> <span class="k">def</span> <span class="nf">test_multi</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">):</span> <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span> <span class="n">test_loss</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">correct</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span> <span class="k">for</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="ow">in</span> <span class="n">test_loader</span><span class="p">:</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">target</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> <span class="n">test_loss</span> <span class="o">+=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">F</span><span class="o">.</span><span class="n">nll_loss</span><span class="p">(</span><span 
class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;sum&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="k">for</span> <span class="n">output</span> <span class="ow">in</span> <span class="n">outputs</span><span class="p">)</span> <span class="n">pred</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># get the index of the max log-probability</span> <span class="n">correct</span> <span class="o">+=</span> <span class="n">pred</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">target</span><span class="o">.</span><span class="n">view_as</span><span class="p">(</span><span class="n">pred</span><span class="p">))</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="n">test_loss</span> <span class="o">/=</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">Test set: Average loss: </span><span class="si">{:.4f}</span><span class="s1">, Accuracy: </span><span class="si">{}</span><span class="s1">/</span><span class="si">{}</span><span class="s1"> (</span><span class="si">{:.0f}</span><span 
class="s1">%)</span><span class="se">\n</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span> <span class="n">test_loss</span><span class="p">,</span> <span class="n">correct</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">),</span> <span class="mf">100.</span> <span class="o">*</span> <span class="n">correct</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">)))</span> <span class="k">def</span> <span class="nf">mnist_multi</span><span class="p">():</span> <span class="c1"># Training settings</span> <span class="n">filename</span> <span class="o">=</span> <span class="s2">&quot;mnist_multi_cnn.pt&quot;</span> <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">filename</span><span class="p">):</span> <span class="k">return</span> <span class="n">parser</span> <span class="o">=</span> <span class="n">argparse</span><span class="o">.</span><span class="n">ArgumentParser</span><span class="p">(</span><span class="n">description</span><span class="o">=</span><span class="s1">&#39;PyTorch MNIST Example&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--batch-size&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span 
class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;input batch size for training (default: 64)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--test-batch-size&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">1000</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;input batch size for testing (default: 1000)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--epochs&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">14</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;number of epochs to train (default: 14)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--lr&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">float</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span 
class="s1">&#39;LR&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;learning rate (default: 1.0)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--gamma&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">float</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mf">0.7</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;M&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;Learning rate step gamma (default: 0.7)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--no-cuda&#39;</span><span class="p">,</span> <span class="n">action</span><span class="o">=</span><span class="s1">&#39;store_true&#39;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;disables CUDA training&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--seed&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;S&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;random seed (default: 
1)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--log-interval&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;how many batches to wait before logging training status&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--save-model&#39;</span><span class="p">,</span> <span class="n">action</span><span class="o">=</span><span class="s1">&#39;store_true&#39;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;For Saving the current Model&#39;</span><span class="p">)</span> <span class="n">args</span> <span class="o">=</span> <span class="n">parser</span><span class="o">.</span><span class="n">parse_args</span><span class="p">()</span> <span class="n">use_cuda</span> <span class="o">=</span> <span class="ow">not</span> <span class="n">args</span><span class="o">.</span><span class="n">no_cuda</span> <span class="ow">and</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">args</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span> 
<span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cuda&quot;</span> <span class="k">if</span> <span class="n">use_cuda</span> <span class="k">else</span> <span class="s2">&quot;cpu&quot;</span><span class="p">)</span> <span class="n">kwargs</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;num_workers&#39;</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="s1">&#39;pin_memory&#39;</span><span class="p">:</span> <span class="kc">True</span><span class="p">}</span> <span class="k">if</span> <span class="n">use_cuda</span> <span class="k">else</span> <span class="p">{}</span> <span class="n">train_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span> <span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">(</span><span class="s1">&#39;../data&#39;</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,))</span> <span 
class="p">])),</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="n">test_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span> <span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">expanduser</span><span class="p">(</span><span class="s1">&#39;../data&#39;</span><span class="p">),</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,))</span> <span class="p">])),</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">test_batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span 
class="p">)</span> <span class="n">model1</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span> <span class="n">model2</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span> <span class="n">model</span> <span class="o">=</span> <span class="n">MultiNet</span><span class="p">(</span><span class="n">model1</span><span class="p">,</span> <span class="n">model2</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="o">.</span><span class="n">Adadelta</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">lr</span><span class="p">)</span> <span class="n">scheduler</span> <span class="o">=</span> <span class="n">StepLR</span><span class="p">(</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">step_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">gamma</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">gamma</span><span class="p">)</span> <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">args</span><span class="o">.</span><span class="n">epochs</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span> <span class="n">train_multi</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span 
class="n">train_loader</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">)</span> <span class="n">test_multi</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">)</span> <span class="n">scheduler</span><span class="o">.</span><span class="n">step</span><span class="p">()</span> <span class="k">if</span> <span class="n">args</span><span class="o">.</span><span class="n">save_model</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">filename</span><span class="p">)</span> <span class="c1"># mnist_multi()</span> </pre></div> </div> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">datasets</span> <span class="k">def</span> <span class="nf">kl</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">):</span> <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span> <span class="n">test_loss</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span> <span class="k">for</span> <span class="n">data</span><span class="p">,</span> <span 
class="n">target</span> <span class="ow">in</span> <span class="n">test_loader</span><span class="p">:</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">target</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> <span class="n">loss</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">n</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">outputs</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">outputs</span><span class="p">)):</span> <span class="n">n</span> <span class="o">+=</span> <span class="mi">1</span> <span class="n">loss</span> <span class="o">+=</span> <span class="mi">1</span><span class="o">/</span><span class="mi">2</span> <span class="o">*</span> <span class="p">(</span><span class="n">F</span><span class="o">.</span><span class="n">kl_div</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">outputs</span><span 
class="p">[</span><span class="n">j</span><span class="p">]</span><span class="o">.</span><span class="n">exp</span><span class="p">(),</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;sum&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">+</span> <span class="n">F</span><span class="o">.</span><span class="n">kl_div</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="n">j</span><span class="p">],</span> <span class="n">outputs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">exp</span><span class="p">(),</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;sum&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">())</span> <span class="n">loss</span> <span class="o">/=</span> <span class="n">n</span> <span class="n">test_loss</span> <span class="o">+=</span> <span class="n">loss</span> <span class="n">test_loss</span> <span class="o">/=</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">Test set: Average loss: </span><span class="si">{:.4f}</span><span class="s1">, len </span><span class="si">{}</span><span class="s1"> </span><span class="se">\n</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span> <span class="n">test_loss</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">)))</span> <span class="k">return</span> <span 
class="n">test_loss</span> <span class="k">def</span> <span class="nf">get_reference_kl</span><span class="p">():</span> <span class="n">multi_model</span> <span class="o">=</span> <span class="n">MultiNet</span><span class="p">(</span><span class="n">Net</span><span class="p">(),</span> <span class="n">Net</span><span class="p">())</span> <span class="n">multi_model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s1">&#39;mnist_multi_cnn.pt&#39;</span><span class="p">))</span> <span class="n">test_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span> <span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">expanduser</span><span class="p">(</span><span class="s1">&#39;../data&#39;</span><span class="p">),</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,))</span> <span class="p">])),</span> <span class="n">batch_size</span><span class="o">=</span><span 
class="mi">1000</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="k">return</span> <span class="n">kl</span><span class="p">(</span><span class="n">multi_model</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">,</span> <span class="n">test_loader</span><span class="o">=</span><span class="n">test_loader</span><span class="p">)</span> <span class="c1"># ref_kl_loss = get_reference_kl()</span> </pre></div> </div> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>Now that we have 2 models capable of detecting digits, we instantly get 3 checks that the output of our model is valid. The 2 models need to agree (they need to output the same digit), and they need to have similar KL-divergence in both directions. We actually have a reference for the test set, so we know what kind of divergence to expect; anything 10x larger is definitely out-of-distribution (we could look at the test set distribution for a more fine-grained estimate). 
Because KL divergence is asymmetric we get 2 values (it's harder for a spiked distribution to have another distribution close to it in the KL sense), so the max of the two KL-divergences should be used for out-of-distribution detection.</p> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <details class="description"> <summary class="btn btn-sm" data-open="Hide Code" data-close="Show Code"></summary> <p><div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="c1">#collapse</span> <span class="kn">from</span> <span class="nn">torch.distributions</span> <span class="kn">import</span> <span class="n">Categorical</span> <span class="kn">from</span> <span class="nn">torch.nn.parameter</span> <span class="kn">import</span> <span class="n">Parameter</span> <span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">transforms</span> <span class="k">def</span> <span class="nf">attack</span><span class="p">(</span><span class="n">loss_fn</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">n</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span> <span class="n">multi_model</span> <span class="o">=</span> <span class="n">MultiNet</span><span class="p">(</span><span class="n">Net</span><span class="p">(),</span> <span class="n">Net</span><span class="p">())</span> <span class="n">multi_model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s1">&#39;mnist_multi_cnn.pt&#39;</span><span class="p">))</span> <span class="n">dummy_input</span> <span
class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="o">.</span><span class="n">Adadelta</span><span class="p">([</span><span class="n">dummy_input</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n</span><span class="p">):</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">multi_model</span><span class="p">(</span><span class="n">dummy_input</span><span class="p">)</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">outputs</span><span class="p">)</span> <span class="c1"># print(f&#39;Entropy {entropy.item():.2f}&#39;)</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span> <span class="n">MAX1</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span 
class="n">exp</span><span class="p">()</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="n">MAX2</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">exp</span><span class="p">()</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="n">kl_loss</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">kl_div</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">exp</span><span class="p">(),</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;batchmean&#39;</span><span class="p">)</span> <span class="n">kl_loss2</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">kl_div</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">exp</span><span class="p">(),</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;batchmean&#39;</span><span class="p">)</span> <span class="k">if</span> <span class="p">(</span><span class="n">kl_loss</span> <span class="o">/</span> <span class="n">ref_kl_loss</span><span class="p">)</span> <span 
class="o">&gt;</span> <span class="mi">10</span> <span class="ow">or</span> <span class="n">kl_loss2</span> <span class="o">/</span> <span class="n">ref_kl_loss</span> <span class="o">&gt;</span> <span class="mi">10</span> <span class="ow">or</span> <span class="n">MAX1</span><span class="o">.</span><span class="n">indices</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">!=</span> <span class="n">MAX2</span><span class="o">.</span><span class="n">indices</span><span class="o">.</span><span class="n">item</span><span class="p">():</span> <span class="n">success</span> <span class="o">=</span> <span class="kc">False</span> <span class="k">else</span><span class="p">:</span> <span class="n">success</span> <span class="o">=</span> <span class="n">MAX1</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mf">0.8</span> <span class="ow">and</span> <span class="n">MAX2</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mf">0.8</span> <span class="k">if</span> <span class="n">verbose</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;MNIST Model says : This is a </span><span class="si">{</span><span class="n">MAX1</span><span class="o">.</span><span class="n">indices</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">}</span><span class="s2"> with probability </span><span class="si">{</span><span class="n">MAX1</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">*</span> <span class="mi">100</span><span class="si">:</span><span class="s2">.2f</span><span 
class="si">}</span><span class="s2">%&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;MNIST Model 2 says : This is a </span><span class="si">{</span><span class="n">MAX2</span><span class="o">.</span><span class="n">indices</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">}</span><span class="s2"> with probability </span><span class="si">{</span><span class="n">MAX2</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">*</span> <span class="mi">100</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2">%&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;KL-divergence is </span><span class="si">{</span><span class="n">kl_loss</span> <span class="o">/</span> <span class="n">ref_kl_loss</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">kl_loss2</span> <span class="o">/</span> <span class="n">ref_kl_loss</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="k">if</span> <span class="n">success</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;ATTACK SUCCEEDED&quot;</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;ATTACK FAILED&quot;</span><span class="p">)</span> <span class="n">pil_img</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Resize</span><span class="p">((</span><span class="mi">240</span><span class="p">,</span> <span class="mi">240</span><span class="p">))(</span><span 
class="n">transforms</span><span class="o">.</span><span class="n">ToPILImage</span><span class="p">()(</span><span class="n">dummy_input</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span> <span class="n">display</span><span class="p">(</span><span class="n">pil_img</span><span class="p">)</span> <span class="k">return</span> <span class="n">success</span> </pre></div> </div> </div> </div> </p> </details> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>Now if we simply attack the first model like we did earlier, we can see that we can trick it just as easily as before. <em>BUT</em> the second model does not get fooled, which is to be expected.</p> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="k">def</span> <span class="nf">loss</span><span class="p">(</span><span class="n">outputs</span><span class="p">):</span> <span class="n">entropy</span> <span class="o">=</span> <span class="n">Categorical</span><span class="p">(</span><span class="n">logits</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">entropy</span><span class="p">()</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">entropy</span> <span class="k">return</span> <span class="n">loss</span> <span class="n">_</span> <span class="o">=</span> <span class="n">attack</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span> </pre></div> </div> </div> </div> <div class="output_wrapper"> <div class="output"> <div class="output_area"> <div class="output_subarea output_stream output_stdout output_text"> <pre>MNIST Model says : This is a 3 with
probability 99.32% MNIST Model 2 says : This is a 3 with probability 33.50% KL-divergence is 587.7392578125 152.96902465820312 ATTACK FAILED </pre> </div> </div> <div class="output_area"> <div class="output_png output_subarea "> <img src="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAPAAAADwCAAAAAAbPrZOAABr+klEQVR4nOz965YkuZE0COoNgJm5e2Rmkezu2efdR93eM9+wqjIzwt3MAOhlfwAWWWTPzAPsrvOQh3WLCrjioioqIor/TwDAnEsen5JzyTn9MT96v93v9/vt/ucfv//++++///4qJZeSS0k555RzplrPWs9at3+Mz28f7x/vHx/vHx+v5+v5fD1fX//2229/++1vv31Z1mVdlnXR8zyP8zz333//4/fff//j9/z3v//j73//+9//vqzrsi7ruv/3f//3f/+//vu///vx7du3r9++fb0nGZ/3P/744/c//vjjebvf77f7/S4BAQERHx8f7x8fHx8VAgAC4D//67/+t//63/7rv76omZqaBQAQ/P/Y5/+/4P9v/0gAAMRfPz4/448AABCRiFlSyj2nlFISEWYmovFXWMyTCBMhjr9/fK5/CyIizb8Y4R4R86cSETOPn0X06x9CJGaR+W/6/GvzB41/AOcPG2f2148EB4AAiHAz1d7VzWz8S0EUAJCI2czclCAiXI+z1tpaN3dAYsnLrWmQlCPllHLKSURYWASRpfSuff3t69t9LULE4zsQpvFbIbHkUtY1CYNpdWvqgZxKWdbbWXvPj/v9tq2l5CRMCAGc19uXs9n9y9uXt/ttLYnHvzCXZb3V1mlZEodWZEBAQAAPJMlFyd3DwUPr8fyZGY4AGN/KrwWzKZsbIUKEK5/nWVvr3S0CiVNabhqU8nqmJJJEEs0Pcspu5r58+fp2WzPPCMzIIOL4AWVZVxYC76GuauO3W7ezNtX09rjf1mUZCyYI5LzevzSF7XF/3B+3dRFmYWZOZVm31pUlZQqrzohISIBjwepk6maBbu14ZkHfcewf/uuCmc3MCDDcjfk4z1pb62FXhM1JynpvwszCwjA3b6TxKkB5PB73Ncu1XmYZAQYglpSXdQVC8DAMdwNkhGVZb62b5bfHbdvWpaTEnxG+NwNeb7fbdrutCzMzMXMuy9pUnYmIQL0TMTIyoQeypAKkXRHcQ+ueGa0/U5KUJCUAAOkAMM+RGQW4mTKNLd07WAAQp7QApbLdj05MRMQUEeHh4ERESIRlu23bUgTx15IJ8S9beglwcI+AAAhkprJsrat7Glt6KTK2NADn5a7AuazrumzrkpmZiIlTKWtXC46ICI0gZhJiJgMgyYHUCMERQush6L1upSylFAe6IkwkZmZm5GFkRHgcZ22tdTQPJJaMJGVttSkhIiGiu5qZuUliERHJZVmW8q9bmokQfm1pM1Mz0/ETCLmsvatFyNvjtq3rkue3FMB5teC05FJKKUtOPI9Qzks3DxTtat26ogiLiLMHsgAyE0YYQVg70Hvd13Xb1s2B/7JgZRYVV3RCQsTzPGttXdEDkCQlTsVUVR0QAABBtXftXW2maDmlnFJKcl28Ikz/tqWbBnjvjZhZiBmWpmoOMCK8LIVpbmnJK3Be75LGzxUiQkIiLd3cgbierYLWikkkmUt4IAmRMEK4IrhW9H6+ynq/1x7Amf+6YFE2my8KwHGOCNMVYQAfmxgCHCKit9pa5aZlXZZlWVYhIWJiJGIaQb8uLRyX64qg4FpPTgkDOWFZzTwQ5e1xu21rKYREiBDIGTivtbGQsDAzEhIQkql6ALG8MHpo3SGn5A4BEcgYHAxuSoShEO2UJMtbVUfKy1+eJWUVM4NACISI6wyzBSCxJB7
fL4KHu3tYO8+jMgKWbdu27bbxiD04/uWWxutZSqWsaxiBaz3EkQRZpKuP3/5+v2/buhQERCAI4MzZTA0JiOYJQEDEVCwASQS9U2g9QrOFA4znAALZTTsjhFpHQqJS1YHTogEAYgBAZtoZCQQIEYhAJKWcc5ackggTXq8QuCEEerir9tZqT+pAKS9k7mbmeuz7vh9nbeqAJLkkBu/n/gFnrV0dmdBNEYzP1j2IxyUqIgzuEBHg435QH2ElICQkxHEXJwtCU+2tJo6ZlxABASFiinBX1QaIgBA+7ifVHlem5a5EGKFJWEiE2dzcw4Fva0lCGA7hSIhgaqZmWo/x0SBORT1Ce2+99XYex3kex7nX7ih5tcJeXz8SfqirGiYCQm/9JKi11totPrOo8Z2Zm3ZTVVViJCZGHvkMkXUPlEDU8THPaZxzIkImoj4SMLexWxDzkoUgfETYAYBMEcFNLWcmKSm5W0Q40LYuWQjBYaQzofN3OY9jP479UJRUmhq41vM8zvNoZ6211rOe3ZCzwcLRXhn1hohIkHJ4WPdw71177xYAiIBIZK69a9fee+9dWx95pxALC5OwuBmQILOqjnciiSRJknhknNIiRlKJQIiIlJYsjGDdf0XYANxMHThI8lJmKg14W0sWwgAb6ZmralPtehz7sR/7oZLKsqmDa91fr9fr1VtrvbXWqzpyAV7Y2wv7saWcUkkpeVPrvffu5uY2vsxxXVpvtdVWe229tdbG7ScsSSSxJIGIIOSQuV5zYZGZd4oISx5JtBmMQ0CylEQYc8EOAGgQ4MoclAIlr6vP/ABv65KFIQIcHDy899567/049v3Y991SWbemHtHr/vH+/vGhXbtqVzVzEKSc2Svo8VzW27pxSlsP9XacRx1J/tg8iEgUrvU8zqNeH0mSUkqSsuQkOaVxlAnd1Ma7LrP8SDLugh7uZqaK454lWbIwhutnhMHBnYgIkgVJWW4eHgABcJ3hiHE9W2+9tt76sb+Ofd93K+tW24zw8+f3Hz9czVzNIiCQOQVTND2Jy1szKphWVrS6v577zHFprhcRTNux76+jHud51uNMOeULXCmp5zweKWGYAQ4lIiYizpJTTjlp+KiScNxlJEsWwrBO14IjHB0BERcLlLzePBwBAGNb59cT5uZm1mqrrdV2vPZ931+7Ldv97OoQWvePn3/+8We4e3g4EhETE7l7d3dPp2G6QdrwRGvH+89nkiSSAD8ryvBez/3juc8r8Ugl5wkqlZaLWU7IKDnhTPRiBBKJ0kSo1M1V+1gwERGXnBjCFD+3tMN4RLlpEOf1FhEIAOjrVrIQRIzto9rOdtZ2tv31Ol77a7fbfa/NPFzr/vzx5z9/n4hSiKTEkpL0aq22WsUob4ppjRdY2z++f+RScgESxFnLgmk79+f7c5+fvJRSylrKUpa1mFsgB0la2MZzEx0REAEp55EzzzdIG/IIPicR/rcIR0BEhHRz4JQXVXP3CC8lC0OYmaqaaq9nPc961v312vf99fLHftYZ4eP18ePPfyJeKQIySinlsNr31+vFvGxvPWRRBm/H6/3Hupojp1G6IwC49Xq8Xh8fr9frtb9er7IsS1nOZWmLWjd3JAmglFPXbmbmeoEMOZelLMtivffeWm/IY4cxMzFBmAOATHxhRIWjnx8M/fp+d89nziWVbD6gP2211dpqO47aWlf7fNcdJS+3t1PHho5R4i/bugp4b40xtJ2vjx9bPp/VqNy+ShmfLASurR7w2o/zrLV2B84WKCP9wfCZGgW4m2pv6XWMd57n90vMaFVPttfrqIZpmUkpy1w04bVgHEtG4OiHgB7v5zGQVE8pS5Ys7maubtpbb623fh5nbV0tRoKl5sB5uT2az4wAAjmV9Xa7sWurQuhWz/19XaQ9T6Nyi5JySimnuWAm2/f9OM/augGnICnCIsIEpkQI4O5m1nur6fpafFSnRBGhHSK0nq05poWTiCQRZsYRZAAQvEArxLFg6Oez1PE
a1uAknDixu7u5j9uga9d61rN1NZgRtkDOy60Z1NZaBddATnnd7g/QXg9hDG3H66MIej0N8002EWEWkbFgBH29juOstXUHSijZRtaI4aaEEeN8tl5r6tq7I2egsXO5t5GwqKqpoSxpfmSe5V8RBqRxhXPoqedTpI8cqMeorZk93MPCXcdhtlZba10NfXzrCsh57Rp0nAeDawSSlGW7v0WvexaC0Ha+ihCgmVOR1ee7QULgiuDtNSJcNYBJJoYIAOGGCBGqpr23Vs7k4W7ASDzRnyPUzuPYx6HGNKqBnPIvOOpzSyPy+AqidwQEmKW6BY4yFEdtGOHmZm7mvfXWuxqN21INUfJiQfJ6EVjHmFv68cXrsWRhdGtHTgQmjEwyAD4AQGCE0PBOr30fEUYaWdIEUM0QIMy4a++91ZwTIiCgAMtsSkA/tT4/3rsIC3PiXMZz9ln8wF8iPLIVNzVVU5vXTgTgyINiJFsDxQ33uLKpa0srA6cFUHIi8F4JAjnlZbu/6bGPx1zbKRTWSykll1zSvN8CEDxcCV77cRxnbSpEIiJi80Jwg3A2Eu2tpZxy4hnZlGYypgfa+fzxZy9lyUWkLEtZSlmWRHMD/+UM00DmxrNTzzq/+V8v1pUEjj+IgNGtUYNxaZkRcgaWso6al2aEt/tbfz1LToSh7USwXrd7CObbfTF1NXWNAA+A8Ne8ph2FU8m59NYahoWBuyEjj1xzVq8InMpIw1I6E9r5+vHPfrs5JEyzkbWseSSj+C9bmpglcUQ/nq/Xc/9V/rpbuI/8PhBjfhPo5m5uhtelxciZUl57eG+7/GVLt4+Zrlkj9173t+CVyu3rrXftvSmZuZu72Wsfl1YIcCrrutaTMQzDARwRgTixJOGUljVQcO7bXHL6SGj1+eP31hxlwbSs27puy/ZrwX9BLQEppZLEemUMV6SBNECABwQgBn5GGSBwZioB4dpb3V/PxcNRKNt57lmYYGAlkkZrcvQQiAAiAokklWUVViJkVIgw095r0wCSDHk2Mq7s0BACHMECHCACQLpIn0AhERGq9jY/vaupRQBxyktBBMKRbIAcAEDAwGldl84iknJefF5ReGUq4/zC3N8OMdKQCLB+7h9LlpWQiIkkpwE9/urgIEkqZdWBfi3r43FbSyKIQCROCBiOYb21HpSC8hqSEkWHrq2pIyeEGPD32IspJUbv4P0ss80rf/58NUNZODFYOwQ8ULIFIAb4vO7HglkycF5uNxORVEpeZplvM9XJCWZF4OHjfVLtHSAcrNf9IzNuKaWUSTinCS7HWLIDMqe8rLat67Kt63p7u61FGByAOAAJ3XpYb6cHJco2MBDvCq5mAQx0PRNIzCmlnBi9u1Ya9VTK8sf76zRIC2cKawcaIOcycNbPNtlccF6A0nJ7c0kpl1zW2mptYU6U13Xd1nUWS6NicnVvnRDCELyfexaM27IuwWmsl6/G2YARSFJe1ti2bR2Q318iLEQM2jFM64lIlAEhxs1vChERyMTh7ugeSCSSck4I7h0AZGYX8vOKcGawBt6QU1nNYb6pHp8LltWC83p7C0kpl7Ks+3EQWA+ntNwfj8fd9bqWRxFhzADuRmb93Bldz7sGZZQyjt/nen22QMoKo21yuy3rsmQh8ABkJPYQgrDe6sgkRby2WntrFfFCS83cAAKRWVIuaXz1ZjSbXfx67qeCLMIM1rwLp7I2DUSAcHP3+Fzw3YDTcnvDlHNZlvVVEoc3cKK83r98+/rVuup4ekcSpkQRZoRgve4Y2poGJQPOoyGG+NmIvSKMt9vjdr/d7yXnnIXGZUgBEJUwrLdzIUqllOKvPXr04zWwLBI0IgUIHBdhyqk319Z6wwntcD1qM5RFEcCsI0hZa58RdrdBeZADAFIzoLze3iiVsizLuiYG7YIOlNbH17/9/W86kZ02EauO6G5KCN5PdK1Hd06rwRVhmm92hAOOlhDd7vf72/3+SElYhMAn7o8mBG7aKidKy7ZtztE
P78dHyikIObEqAgQ5EY0IR/d+nucRsyQi66oGsoyD5+55vZ1NHcjB3dT0c8HaZoQll7Ks67ERWDsZHSmt9y9/+4//0Hnh11ZrbbU2ALPOTGAdQtvxUspLM+ScRwsQB9UkPAJHq4Vv98fj7fH2xjw6nIETZrNRoLczzzNk0U+OfjyzBSXgzIQQ4Qjj0so5dfR+vl4vH8k+8WjRJbHezFrXXm5H7RaAA5yeuPQBANYMOC+3t5SXZT3X8whrxy7oSHm5f/3bf/w/eh0LPet5njWJhKs2IQRrblVYLK339ivCCJ+HGMaztMjt/nj78vb2hcbzBk6AxExs4wzX2oPScnv7Zno+ydvxsQCKI6dEABFGc0vnkip6r/vHu05cjJhFSIQcrVur57k8jtENgAg37S0AQP4BAOnLJlGf33NqvfXWDJA4pVwWEvB+Pt/XPuJaazvn/2tdLQDpqmm/3DJ6ff7AP3+8v47a3c20t1rPZsB5VbD7VhisHUSAgECI4eGGeLZRXbYlMYb1qrV1NXdwM21CmFwdKJGkzOhase3HaIA5I5DwgGlZhJ1GywGzoFs7X2zdTB0oAED+EwDk65a8foinuf8DaDydKGh1/yjU28Du6nWGz9bUHZAlL0tZyvK4F7b6pP779/ePvXYL195bPY5mwNlBbC2JQ08crV5CivlWnF2DJC+2ZMbQuutZWzcHANfOFJEgAJhSSBL0Dk6vUVepExBLHvfCWPDs2C5CoO14YYzsGBEA5L8AgO6bxPlupwDOb+d6SQSs7e/kn2f4uqXr2bsFIKdlna/rwnaSHX9+//ncz25gqq3W81ADzsjZkyQOBeckEoJIHuYR7rUNtDSWLODaDj1qU/XAcOsE7onpavUzhnqH134cZ2sdHJBTLvlzwZMsQ0uisHY8R1Y8GUvyX6NYFq9+vqfR9ROCec+4oNedoutEdtqoH1Vbb2oOBCmvt/v9fi8pkVev/OPHz4+9dkPT3ls9j3DgLEUBETFUa8rZARnJJyLWmgZKdiwjwtDP2tU8ANw6hFlKkoiSJAwANwAbEe4dExCnsiy/Ijw/JRFYO1JMgHpE+D9hIPFxHhBSci6lQAIa6a8JWNtDz9lwaH0+995H7x4pLevt8fb2JQGA2Qnx8fPjuZ/d0Ux7Pc+DEJkRAdzcrbsl8yAOJAxT1d61WRBnFJkpcptbGsMVwrmnpTBwKgV0fEf92I+zttbZAyXlZfv3BfOSKLQd5OOAT8TjvwDA51tT07pumwHLjLCqoNXo53O2t7r6TKjNVM0BKeXt9vb16zeaf0d/Pl+vvXYj095aPY80fhHG1npTbS0HkCRAgnDrrTU1HbUeMqKr93bUNiMM4EqcgBJQWjZvzb331o6x4j5qorKsv7b0rwWDVQZLKXsioM8I2/PDj/Pj+SGPezPgHDhYCtoJLDoR6ejpqcb1pIx+KnJa1vvj62+/wX74eR77fuzHcZzdYFxa5xkonEvJeByHgZ5HR+LkgYTh2tpZLSJIOI2nWyHqWVtXDwwLBwJMlAw4lU3Bu/d6nOe4pDs6EKe8/iXCs4OTEoc1tF7KaLHSFWEVO6N+/PGHfK0GnNeRDRY1vhqbOjJL1YFDTcxnPEvr7fHlt787++n14+d7O2urrTtelxYKcF63lQWshtaXMie1QEJw7fU8nJBQkNDM1M3sPMYZRg+HgICUigPlZeuu6Fr34zzr2Wrv5DAA8P+xYBYKBe9NHYgdiABANgDQVxYM69Xa2EhIozA2MDMINeszibYJew8SICLmbdvWbV1XHS0KbV09AAcxAyHcxuNVVq5CYP3cTVJW84CBqdcziJmRkSMcwd3GT8jFfFJLLihhoSZM46cCkVjKeXIhRSRxEvYLmwCCUFcaBzTms/QBAFoNy61Dkre3L1/ebmsxVRncQYbwCG+11UroPrNfnI1byV+/bIW8vqw55c0ot9paq63x2+Nxv62DbkYQFtp7q7WeUGtpfcBzvffWPi/WiHBMkoVEcl7
Ww0dRDfJ4PB63JRGnVNQAeOm9aW8932/3Rcg7IAUgiQ/KZzhAOCKipzQ4lAgA8g4AfhrlG6RV7vf7435fi/aUhJl5LA/wPE4hCLXZukLOJS+55OV+u2W0+ozqmDfKWz1rq7VWejzug1+XhRHcXHtrrdYKpdaBmM96xHH06gf0xIR5pAHnGRNS4m27bduS2CSrBaLoqFNVlmVZEnpH4kBkcZl4C9jAIm3sp/gVYZ9924fc1m3btjX33kVGUi7ExLyPdghfnTKQvIy9XJaSySpAd8qUVz3P86znedL9/rjftm0pSQghPPq4tk+oA3hKc8XNxrcKKEmERCQkl2WttU6mBfKFuLKk4oCU3MzdzElSkoTeiD2AiIkv6GySSr11VY8Bw8oHAIQ75bS6y1LKUpYl9zbYOwP1lSRFCF0b00zEgPN6e9zvj42ZiLwqQFDOEHAcx3ke54H37X67beuaszCFmw8QvZ6w1NZGhEfZqQMBBswBjFIKpNJaa71NJgTR6O8zsSQHYikTOPd5mViIeCCyxKCHwfjLEWFl5OW/tjQycWFmksmIkd6aCDNHKmXoIEanhGm03RA4r/e3L1++3D3c3dTpuiX3/TiOYy9w27bbtq1LEmGEMOu9tVrribW29pc93fRCv1ekTFJWnh2sPjofyDihSXIJIJKsn10YG30vS+MeY5ik1nCPgTG3kbX92tK4lFTKUpaxi5ko5yvCeVnXsq4CbvVMjJNEC5zX29u33357tFZb09Z4KZSXsuTRaC0ZtnVb121dhJlx3MhjS9NngHvvrbXWr04HcA6UvCZ1VTPTUfohY/jgAnIAsqgaznOvrbfe3bXncbfHL+74hOGuJOYzwgyJ8u1xvwtOnH3uaIK0rLdtu23o2o8jEV29F87L/cu3v//j7bW/QK3uTInydr9tr9drXUpOsS7ruqzrwkgEEDZu5FpPqrXNJc9ry200kTiro5Qtu7uFhc3inmE2YZUBWcwtaHYN+3EcR3Rv2TwQSWJCiLONa2pr/7dLS9JG5f712zfxsLAwrxf2OEuDB1irr5KYLnIp5/X+9u3v//H150/Q0+sz5YXy9vbl7bmtr5JT8qUsa1mWMnQVPpDyWmvlwWv8S4hn+yLyokFS1tXHpeNXaxAmAOGBHB7hMLlKUj+eHHra2LZEDKMzGuFu2lXVW1ez6wzfAICLoPfzxeSDOWWv1+vUoAQpD9g3MSNG+ABUSOi3r18et20pKedcSlma5MQIYeoD4sg+rqqOoyBHGGSi3hodpZTEZMfRHaV4mrHwQYkHiACfe9hIiThaa7W31h0CAJCBJzWPEALCfeYoiGMztXqamwcyMIbW45mQAUB+AwBaC+oB7YXXgmuttRoI5TRIovNMqDEPoP/r374+bksiJE556R68ZkZrJw1chRgJ3Dq4XhTygaRpb3TuSTCidVWnzFl1tC3ttpbECDaYBe7zgUaKK5mPKXHACEdDAB1cyClLgYAB59TzcAhAJhEOPV/ola4FozDp0V4ME2v3sfcxDSSB5q1nZgayrOu6rV++fnu7LYmAJC/qgLgWBm8UdSxYnDAs3GiS9yl8XlyQhQncKwQEE8RFkxgLpnCdRaABBiDERcI0g3lZDUzPItR08DPnigHGgo8dhugGE4NWtLZ/Lhg8XLvH1RE0n69tSleErzNhKHl73B+Px+PtsY0Fp7wGEuQkYB2s9z4WjOBg1lFEWGLcIqq9dziI0FXPUTaK+AS9Z4TDsA+cwca+/sv/4tXMj3AA8Jicy0t3BBBh2tt5HigixDwibG1PeC3YW61Wa616RXiweZOUlBNPPZKHqxrwsj2+fv16v223rSRG5LQEkDgTozWtQyOEjKPSg5CUksfQj1jX1owJwrTXsixEZSlwFZ/LVrIgWOh8h93cwtxiPg6jdQ84UmwP8j4ZaubuPk+/9lqPXQKIKaWE0K2N9vBYsD2h6fF8PusV4XVdlgVzyv+nEX779re/b0tZyoxwAEk2CAyb8q4
AwsGCcfNkHoHEnxE2AvfezvPmlClvd5j8Wc1LToxh3ntr2nofp9s0LhiDOQKQAAIcAB3/ryOcABOyFPHZf45rwQqd9Pj5489zEDgsHve7YcJU/s/O8O3t29//cx19eAKSIJas3dXUTW3k2oSuHqaqmj0CiP3XGYZQbbWcTtm5bG80EKSmkiUxhUdvrfaRj3XrvfuQcoiIBCD5tWLA/+szXEgcKOXUTHttE5f+DQB631GP99//13Et+FtzSitI/h+3NHLeHt/+/l/lYsYQA3Eya702s9baTN8JHdx6600DgFj885YGsN7PnA/Km1O+feGBh/dOQswI5r0NNm2bXGQfIpKc3QGRx3WMAf93t3RIdmQpyULP/dgPvwCAnhm07u/fdzd3Nw/JS3fglIVpUMfNA4CIJZVl3e6PfEkcAJAAkcMgtJ3nOWJxvYa1NoPBIhl3cVf1MLPeUl9v3QI5y0V8JSJkHL9zrcc5Vtxa85yzZnN3RyTjX835NoA2AIBwU1ObCV0lVXcgGl2NY38ZAMgPAGgfezXgshGMigq/ff3y9e2+ZAoFrSz850f1dPuq/ds9R/v4gwoxIzOj9m5du9V2tnq22gYhiDhab6MWGtxm9ddx1m4WgKMjlsnb/p7QZIK/NmgRLNZp/JJmQVmKATMxeR90E2L+vCH6WdUpFSoc/WRo9v39eTYHpvDeDox0KshCeXMAkO8A0D9Gd2CVweJDfns8Ho/bmhl08LyeHzVk+0r25Z7i/EAtKYmklHjQiWuby2t95hkUFwymM8Xy134OaO7iLaC3PaFWuRStuXgW4kSC4b0eL0Bkmc8uQPTuMQStZKo2hBG9B8mSMoWe0Q/98f46uiFjWD8pPJmBrGUbl9Z3ANAR4byZMCcWlttt227bknn0wF3reUbaePV1SdE++r6UUnJZihzHeZzHeTRT1W7dLtArbGpTBhVXe4wIO8Bs4yfytqPWnUc5HGgBHCjZRovpeLEkziIJbSZjDojMKqzz+3X3CEocwtGjn9w/3l9nM2AK6wTWMxKlPEVj8h0A7GNvCly2GKKEnJdB280A2s5WWzUPl9vqwcxQ+4uXdd2WVS3vr9f+eu2vNpIhN7gKVR9EgxhQTG/x2s/a1ANHGz9ltApW93ceHG8kR06BnEgQrNfjWRbhXNaFWq3VvTdDJBEz01GINCckwoQEAaoB0ffX6+wOjCMPSinnIqnkjNeCfT+bBeeVylKWpZRPXqaHnq/jte8sIxdBU9PTVNfb/XZTB3s9Pz6eHx8fbUAM4RD+i8yGAaBqXbX3eB1n6zY35YywVWahSfNllGQjwuMMv5wL5e1+5+OF3qOfhiRdkrn2ehz1OGFKHeQqIPt5jI4lhoExcdooy3rbbnhdWt5aNZAMsm7rtq3repUiPbTuH+8f78u2bXJbN96Pvbf9daxvR+0O5PvHz5/vP3/+7Bd7b9IjPUa3G1FVi/bWYp9ccrjk3+EWARGcppKDUzYfEQ7rdX/h4pzXx5f0Ad7I+2nEkpK5aavnfrwOKmURSks5o+t51rP11rUbMIQ7AkCmDLLev3zhK8JhZgqcJd9u9/vtfrvN2xoM9Hz9/P79x8N4TdvXr/Ljh77Oj+8/16OqAQnsz/cf379//96nGhLdpgSdmYiZWLVrb6XBXyI8YOapUVLOg05X0jLeqesMP6UH5+3xLZPrQdFPY0lZ1Ux7PfbX88UbSFAqm/fo5/P1bOP+Rp681Mh5A1kev/3tc8Hjw4zL4/H29ni83Sed0Fr0+nr/84/fG68h29f/SKR7tI8//o+tdg+SjPvHz+9//v7HH52FhIRZraupWgyhLSftvefcyhXhAETilFJSHzo2mg4fvWzdgiSRzC29aFDeHt9KtDNh6Gkiqaua91aP/fnxkUCKoyxbP0PP5/uPOlI95AhT62p5U5D18e0fCS5yKUzUMGL0AtpVoA7uY7tQMGIhGv4DMTiHJ9WuFkjsExy1CKQQJJxoxWVlMDJuR5K0rKP
uHcBePWkWBnge+1JKYv94jhs9PkkdpSzLsqycBMO0xnjg6JJ0RgQQcUoZIBAwEDgAEJCSIIRpO+1CPODydQDGsHa8thFf1+Pn82iGUhJjuNbDugPn9X6WrQiFVmpOebk7dhwI8ugUBAQiIxEhDyQ0JwSUvLR6iuSUE4Wae8Ssbt1Ne92Fwfrpv//581UVRJjHIzwE+R2w5EzeoJ/dUZZgWUfj+9SgvAXn6pO+jyge7pG2jLr/JGMAkDsARLeR7/mg6S9l/BPm58fz6A5c8qTZXPyXnrcio3ftlFZH0UvbQVNIdK0f0yTLIUleu/Y+WvQ8FgyAV22iXF+E1uvLf35/36uBMA0CFIx/NgiIibwb9e6UgDKXJRNYP3pwBsprv24GYAAAQNkS9IMn4nEDgKgdQsO6hrYzp5xGEzi87fvRDCWnmahZN+C89pClJHKt3gPTgmmZhZyZ8EDXcIJSIePZSJSyhbnZeKHDzTwCAK9aTOkksHa+tnh+PF9VQS6/DqBUVgsUc4hwC3ALTJydkiQKa65BmdOmvZ2tVoLZJ0aSkqAfdj4/QTwndMew7jrMCngqDUNrq82ASxIE03pac5S8OlHKQtFRzSlRWrvO1BIv8RxdagGZ8hLCGC/z5PPZhAoQCSPclbCC9XMvBY7jOKoByzBOiCDJKyDPi127IxARwdC/gYUCMCMA6L4fQmF9fvHMjKj7OSUANwBwDOsY1rVPCeo0+gjXwanFxBj+uaUdExAThgYFUIaA0FrPygQw3A+WhVS7dVWd/EmZvBTCfpznie6fWxoHCUuxWq+SUoLeelMF4VmbAkkJ4rzUo4J7P3ri0SMdrWpTYGFJLKyvkhisE0+iLZtZd/VBPbwBgIf1cUYnTfrivcMgggNzYgrTKjooOVJiKq1hag9IjyMJoXtetm3bto3aEFC38XcI81T6yvl8PdFazKLz89JCsEkyw3Gi4BNtAUyBkpe+C1mPdtQlM8pSCk7OqxfmvJSl2E8hsF6RUymlLIXO0/Q8z+rXpWWDCG+9mo56GhGBgGYzjpkTI1hvbO7AKMVH1W3mhTCVkrO+RoNR07o97vf7nWqt9Wyn4NA6Uyq5lJxL3ougtyPUPCLwivBkK0LAgOoYSYRGszdIULKZZfCK3o8jKCgt2wattQhrSpnzdttuNtYrJKmsy7YuRFb78f58GgDICgBWExO4tlMn7Wzm8phSThlFBvnVGhoisARib62HajMafNBVE2O49V6W7fZ4e7zxUY8zncQ42ZCyzN7Lk8HqTmHmI7EeD6nbpZ8wSinnPBIXHD47TEMsrS2R93pIdpSy3UMwFKzXHJTX++PN0a3VI7Hksm63bUM7QY+P798VAOSfAOA/39/fn6/jbH1cJo44xGupLHkppSxJcsopCc09D127qtpEVHojs6BULHgrCb2fzIN9VQZsbg7GnCyAfoGCQEPzdj1kaKqjd8+plGUpy9tWBKwecJFpzqZBqaywjGdRonULJIlLwBjNMZVbj2VZ13UpOSYYJgAA8jsA+MfHx/PjtZ9Ve1dVn6xPkbys67qsyzarGUaLkSlfO3q8J40xDEhKoCxFyPsJZKZOAuxT7KSSBv1jaKbCfbyUGLOZgKAIQ7lMktdtXbfbNuAng9nePpoFSdloyUyhjaB3dWCBsUkrR3eSxYLKkNcm/58Lfr5er9frOOcZNg8k4iRp2bbttm3b7eKwTORR+9RrBQyDC8awICkoJYmg9zCKCEAmULMIM6WUzXyYxyCED1U6XQSFAAgal1cgpbJs99t9WRdBa3ufRndxNg1MRTknQbcGYGoBJCQEbv1E6I5pAUpTXZ78Knbxc8H7vh/7fpx9CPB85s05l/V2e9zu9zvMGtcgTGtrzUdBBRhu1okADSiRqCMSetc6kB4mah0MrDcqXT1gZOODd0BDR4dXKoLhboQjwrf72yMnEbDqfEmIWhs7KTELhoZBDFOauaUxsDum4LQOgZOI/c8IX+rZPuRSHgHIIik
v6+3xeHs83nw2ewLcejvP8/Le4IhhBEMQKDzEeuHdnUQk0ZCOYlivWLpaIDLPbD+QJOdcEk7tqoGbESIAS1m2x5cvRARg0T9rHO0aJIUUASnUcKCnE+1UDCNzEsqr8cAZCf9lwf8EgKiT6D7QXfOhD0q5LNvt/vbly5cv2kZvwAbkeexXu3Jk/QhxKRqx967ee6dcUFByQVMM6xXWQbH+hXMDSV6WdUFTV3W1MDUmAKCU19v97dt0WohLtI9uFpRIxkM9Dt9sShC4hikhUoJpKTcIOP8e4ZjE0W4w7+BP8ch6u799/fr1Wz9rrRWjo1tv5/4cLTLACDcECJcklEQSnyeY9/PEFTkhlzV6GwtufVCxP7c0cCrrtm1ofaRlrqpEGMiSl+3x5av13rr1rhgXOyEARSBUVcO6QWJBYBZEcDCkkVAmSZ/60P+5YJ3qnFnYBY4Ip7Kst8fbl2+//Vb342AcNlW9nvsLc0oJkGEs2C2ISUouiUPR+/lC4OEKYjOriforwvgrwtvjjr333pS79T6ItzRki9/qeVi3erYppZtSOSJsNSy0NfSEjCQpYNyivBCmsiwlhtjD/vWW9vFA+Lg1CJAQkNZ1KTnJ8Dka3Eof7+0g36jyRGoSEQE4gDIPuxN31dbOHVlSKn14JfTWatQ68aZAkrxsepsfaMwI4E6XhG+cqHWjsE7h2q8tzcIJkAUHSccMRYauJ+bNRyQplWVZ3czV3Dwva21NTQFAvgBATD14HxogZklZsiT0Xo+SEyGcg4t07PtR1ZGTjDSx5PlghVtHN22yP1+v/TgrDpqE2+vj4/U6jhr1PI5jfy2nYlofFuPJ29YM4T6i/vm5jB5lmkBexCYHIBqXahBJXjAlyZKSwEhNkEtZShJGCkAgp1gHULYYAMhXAIhWW2tMSHm61I4iHqzXcxeG8Hocxzkw9yHQlzyA+HzhQ64Qpk14fw17KZRB3rH9Y3wDUc/z2NdlaQqyKvJjWcu6lHUxN2O6FuwXLMCSdDjNMTlMZhMisQcgJWRRs2kRxwIIgQDIOQ+xmDMiejCuFiR5Ga2WrwDg9ayVCYHXmVm5m1uo93owDVngNNDttasDp1RyKUtZ8kSjY9yQTHjsx7GfZx1vj6nur9dr388a53kcx2tZzCCtkJaeprucuSr9RY76lwiPBI8oJnMlxnoBGWmaZ/DoV14c32HPJYQUiBThuAZKKtvDri3tZ0qMGC7r7X67325bb7XX7i3aSQjWWzvPWs/zbKZqjvyJquXBCUGPMAJEhPM4zqOeFRBnb3rf9/04TqjneexlKaAoi6x3vyrlbn1498VlnDrBepl7mvByAgMinjXWKGDxsin4bEBOtx5ECIqA4EDOy1brrwjvIoTgntb729vb29v92PcDoblVAvDe6jTJqoMzB5RSLqUsy5pNlTQc7JLe1bOeZzsrYLiptnaO01ChnsdSypIlIMkQEk70qwv/6xmGQYhMfWgPiXxCqkA8YANiJMZPuu/n/X3ZI04sLDAEJffWe/9csCVmCDct2/3L12/fvr59fLyzG7rWAXQcbTq0jP4+MuYhy1yzNuojJ5zAX539RAB31V5rHcZJFc7zPEopqfCErK/6F/rgRw426LjsR4TTjDDNNL4Di89McIgsJ0wKMQPLdEmRPg0bOE/e2+eWFhzmL2W9v/32t7/97euPQtAreccwbWfOl9rchIRZSMaWXtfcEWMYkE3q0aioW4Mw662m3AeBukGtR8k5pyhJlmUt+UK9vH3+5sOq1+FfbmniC9fseMEkNN2k8CrGJI2EA2fCAUDTJC5mBgKfESYAV+2trPcv3/7+j//4rbDrntB6uDYRkT4husgpJQROucwtTQBhQ64x+pdDyNU7uKpIktSngxrUsxw550QMaXk8buvlBWfpM8LxLxFO/S+GoG6qHWUgYcijyVlwdmVdpr8FjvR4kNrosodEmkKtBQBMdVwT5bZkBh83MXEuceFvs5rD2Z0AYBl+LqQ
zFXFgIM6L6ei1dLi8f3Q+8/h2f9yWzCOtHJXoqMvoSmzMgTgVWwSs7u/fl2N/7a/XUT97HzMTgrAW2s99RjjCeX4IACCmJB8CAJlJSEh+LdgsAhA539YhbhqNo7ROfw8WhIgwZ0cIN4gYviKpQRuEI3XgwTXR6WQCNC+QqWUzuG3btpU0o9UbXTpHQhyUEHUgyQ5ZUOvrZ6F6HsdxnGefQrXZtgEIjU5MPP8IcNzb9Ok6QzzvFAeZbrz4ueBZACe5LYlCTz67AUpe6HIdhPBwMx7YbbjNBQuO9fbpB0xEbjEa49fDOGkADsvkdn0u2EerhYgmUqQBKAkoJbDzldkHql5r7wMAIxp3UoRqDCLP3LnXIxE0+8XpolzEtPwvCABSAMDDh7ySx4Ib1W5BnJdpncg87bg5EDwi8NeCW+tduyowZUlJ0tVyiU9vFxveEJDHKbtoOjxRSuT5Z7SZA3EmFgGtT4quvWvrvQ/P6OkLBQDDV8PUFSe/9zrLRnkZHjxDEduaLwMo758RdggiFslUSqLQ02sfjXihmcaM9UofX6Qh0LAauCKsnRKnUpZSpvBmon0eF/UZZgnNV4QvnS8BIYK79e4OJMiJBa1S9OPCMd3C4tcZRoihd+11lveJpu6107qt67qp1/M86nmePv5E018LRmCWlBcQEQpFO0eE4wLGaaxXedhdRgBJEpYk1KYoUYDzst3WbSTXQ281cD6AmefCAGrmghGGcfzY0mHaGwYQMgQl0NP7+TGf5ukyCTDhIYCwNlJ7GHzuwlPQXel2v9/u6nDu+2vfj91ut9tta+p0belAZEm51EBEAjNqbXS7fTDFGMNMh4frfDYoiUhKiXpvXbsqIaf19njchw0Mol/K24n/DYa4m9nYkAjBTDxIafPRocsVgMBqP0cPEhGBAgPjl+FngLbj+XruT5gmUvkYX8DBj7ezqgO/Xs+P5/P5tMf9cTb1oCvCMWhiSzP3SSJWAxTgwOFPQK6mXZjIIczNHGX4HtFVIgtwWm6Pr2+T+kDetXXtrTOOaxPmYz4jDOApeGANiOGuvY8kQxjdNTzc+WJbA8El/xs2UFqP18/3j3eYfjfl9Xodr9dr5y9n646U94/39/f393fdj9rUYDDicUR4dmwvNr0FIANyXD0XJJZk2ZDMwPwvtiSXy80QpJayXMIbJ2GxnvulwAkENwjrPP597obohgg66a7OU/YIOrzZ+sUpjWktjIMvExGDZ9gbDI6tUDuP1/P5fDKMAiI9n8+Pj+fzqQOtqo3hEmq1s7VaWw1AIEBBMyNDA4ch7oSmjlIo9W6jJ7hs27oupRQkJCbmJU9exkyPB5MXAhO4WYSDXb4JsiyOQlKQQr0TvvazaSCnWSoMZ17Hiwc0vl0MADBHZEoenFdHLlvMITOSy8xvKLQdieAYuo3Ft21dllIKAYA8Zz181vOsn/iejp7HZdAaAyvkvFhT7dpU87aty7KUPBwGhHManmhs87bCwb5iMTdT66bTWrAVDUyBKUeEa0S8XscwVeKpd5zb94KYBgIWER4WhMIWwHlFLrcz5iZg7V17yxmGk0Roq2cPTGts67qWUvKvCJ/7eR7ncVIpZSFKmXV8qw7D+94mMQE/zXjzuo4IEw2LecmMbv3EYWCjiuM2SQnUe23n8D1prdU1SEpQyoN0arrvZ9MAToN5jBhIjuSj4QQIcN357siUxHz4/vTeYBbDkEvvufTuFNYIrFnX7pQQt3UZhv2fEZ7WyQdvt+CEsvSZNc+SVnU6ZTAOAnNveTCNSkZiVhEd9JqGUXtttfZK233bMGFytHac+3F++nuRFAuS3DystdaO19kGIY2IxzeLSH4xhmBUSoNriSJZzYFRsnvY1WWLIYLK2Qi0hdXjyrtkXddlKfkvC94HXXLPPSgtlBZCCHAjBLfea28JCFMumWaUWirLsiylJGJmVVEaniNgRz3O4zxPfuuGaaXcwfr5+njul/O0pbIqUCquoW34jY8zPAS
1l/vyX1Y8Ms/eNTinruYwSLaIcTFYJ7jelMGadeLBNyDK2/qXM/wBAP56vp6v1/NVnPLqmMroUJvh8ASqdaFEsmyrDIVYq0OIWkpiZlE1RWYMDaPX8dr3175LNZDFMDFYO1/v789JCehR1mZBKXcMq+frddY6jOsRkQARRyOH/AowwHBHaC0k567mMVNb8em10HIZ/0WI4VSUUxaUlMs2Vpw/Ly27LvCV0np3lGWQTJgw3Hqrx4kpKC23e6pnrTVXkZLLkksRZjE1NQDEMO/wfH28Pj5eH0lRlrtjYrB+PN+/v09lQ8ft3jRI8jndTE/t3QIpfaIUiOSIFNchHtXGWSum0towhi2lLEvxetZaz0rDoz7nmLCCLasvidKyreuyLOWvCx4zv94/brI+mlMqEOHGw7hLWz13WZxSub2Vs5xnTiI83ZzF2EzMbDjIedjr/efHz/f394R5uVenRODtfL3/+HEZX/DjqBaUilBYO18fdV5HeJWygIgUEfHXLd1qPU7MdfghUF632227+X7sxy4UuYwN5GHWem/tZshl0BPWpfzLlt7P1i2QKbSd+8citdZWz1prUwPiVNKQACsN845rRBJO5cGnpAQHC4kQEa3X4/W+lh8/P15H7TpNXWcx4LMKJxJxmWwaNzf1qa9IMExkEQkd0c1GE3EomMO0MYH7Pm7c/WxOeUNZz/Og0FDV3luVlIbHwqDWzlu6dkfOXgT68ZFg4jltiMw4o0ze3d57b90Cpgg7HIeFiU4BM1EMBLd34ujHRyH/+PnjeTS72nyTs6N69VzKmq52d2+tuWkLGciDzP4fKjOE99Q+IUnvaL2+sk6v8dO6YaZyr6+XMHgHt95EEGE0BvsA8UamVdVRMhQOPZ+ox+DsD0UQCkopRci1UjNVM8cpZZp++6pdYc7bgMEE7W0smL3tz4/X3gz48hLGUWz2bqOdZno1No9jj+79jEJcSinl6gv3gZrVX7O3vHurO7MOBcJ5EiLmgtjeM6E1grEHIoJYcun6a8GgaoGcOXH0A7U+p7jaZ30OkEaEg90t/IowxHC80t4V5qAWTCmXtmgPjn5QtL0ex3k0A7rWO4GxPiOs7hNGxyeFUvQzuAwO0mwYQQMw7fUayoOE7j0QAPQ8z/M8a72m5PTE4e1gcNNG4B4sKZfW9fMdDvAIFA7m6GD1lUbJHoHMjHOuALmG4SA6AF1GCjE5Hx0EcLSVc86lq7qEntH3j6a1t+7IbMP6GmFqiTWAky4QsyRCCq3kekIJKdvjyxtMLkQN763mOaWLaAJHbjbumnq2+8aUtvumDNqOzBA2IGSXlMva1eAzwtODmxCjayWanBpAzgmZ0zA3DtU2qIn0PyPcAZAYiSWlXLqqdY7ubWexYYr5fxlhwNm1FAqtO0U/YQPO29u3366RdadrqzX/NcKT59dmYtAdF8y3r18MtZ2vxOCmEKbdUi5LbcNLYpgeXLxFmA1VY0ImZEoOglKWxcNCwyKGcSv/2xnurSPynNyTci+qxhG9xbCIQSCYk8v+xxkOEpyeyax1T+T9hB5cbm+//SOmfeWhvZ5Hlr+e4fM8j+O4Ur9ORTHfvv7dvZ2vJRGEoXvvrLmstXW1z1saSwZCzsVb67XV4WkhLFJQHDivN9WmoU1dkojQv93SXXsjEg8gSSmrqrnjnLminMZ/Lq3dv93SyInn8Ao59/dEoSdocNnevv0jpn1o6u3YS/ocrEXg/Xy9ns/XOZ3NbLkr5tvXf0Q/nu9rZvAINCTSZT2G7zxc7sMJSFJZVwMNPff9SJMH6pItkFNB0KFBDASkSzQ7eUaD+ePXuBNiSVktols7j7OmkkuOLETMLhEhMu/cmFSuVIbRjdy3YWkKQCR52W4x547CMTgJfGUAFNbP/fn+fmrrXXu3Y3LX4batS07CMdtzlxdYFZh+WrQsy7osy9pzEgQ3HcP43E1bTacQujlyYYshsvM+OLSCCMQsEQQM3sJqqrU1C5JkqnP04YB
4Cbik4m73r/dFoB9j0JtHJLUIBEagvNy+7NUfhfT4mQedd1TDI98Ruczf5yeYk6qpb5msPn8s8ON5KqT1gXPMQ357WyXqc9hLjQVPinNPMp+v2ZHV3tsp455BgTRauKERIkkiEBmROAFQEFhYIxrgN87u/WWaYtQFWAgIcb3fFgk94HPBw5FCKDit97eq9raQ7T+gXQKqixQwkbS/LBlkZm5bRjuf3wm+f5yKsj4GLpAkrcsq3p7jDP8nAFApw069MmFYrxlHXMC0CQtBDPUYyaeBiGTLwy2ZhsuRQZhFxDRmJnFV5Wn1NFz8hcfkrLKUwtBPuxacp7JHgPN6rxq6LqQH9NcvidzVQJ3rLaWU0kopk5LjcUto9SnxGeGLeyjCLNHs/FwwjlF7pZwIob2e55zt4dqlMUFEzkm4pIJnqx5aq5gDIA/YDQiH2bup+nDtQIkufE1XGJJE4bwsy7IIMzP0qBccV8YYtGxAebn1oC5MureXLEtZlmXBGWEeK/7rlr7q/9gy6snR4Pk8FdLyNr6UsnCYh9cBN84FT0rOAWG9nUeazsGA2hgR3FZgkLxsxOQ9tO4pAJGHwRGQU1Az67W1Tpe5SroczSPcAICA83a73240Rov2uBas4znrGTitGiTN3a27+yiIkNFjZDVTIfNrzVM6A1Ay2Rn9BbWeirLSRc/moTWv56dz6eeE6d1N67nnNAQpENqJANyGFWfZHozRyMeCicQjgAIjGNDA+nnuNeWUAUVQ5DKIGBJMAinb25e3L3H1kf91wVkVOK9OaTnPgSvW+9vZHDnz3NIs/x7iXw1gQqt9JzA1g0R527bbut02ej4/7GyTET8WnErKKeen9nYca0kWaADuoUODplwMJG8Ptp4wtO4JiTmZw2R3QjTwdr6ee1kXYCGhXwN6x3o5OG+Pr7/9Zvt+eOvHeZ1hI0m5NLWgFJjKdnx8ND0+3j++nM1Rcoe5pccZzn9Z8Zwhx+jmOjI6RKSU8X6/3W/3+x2/k57RXoMR/58AMJj+KaXc27k/l5LIICzCnQDcNfW0DsJ96jVRaD06sYzXZH78BOvn6/25eVDKKCFzdPAYDgnEwGV9fPv7f7R3jgZ6vD4hdU6p1EUNOHNaVF/QSPcff/x5dEcuq/IkMkn693dpikxkWDidrcKYd5PT2+PtcX97vAHb+eH1+b1frRZgxuiubT/OrgEkAOCjzp6DllP6tTvdVFt0tanCGXj51eYKuP4US8pFfVByAPAyYg6Y1FV1G4Kw8c/FkG3NnZvSNDS8Zi5P3GPqzXU2yCl4khRcIYadOYPkZV0ejzE/MuYEsdbhmpmmfbADXn/++eN5arBHcFA4SRpE+sdWEoU2rbW1rmp8OYXMruB4JHMuPc+pyciSSgDL/Dswr5lD6/5hpwYXwzzBD52Dlq8cW7U7pfXhmL5+/fb1y31bU2ICt165nsdylILT6vOC7OkzHDi8Nrfb7bYtWSguS5MImJM8JptWdX//eH8ePYgjiCiIOJdlW5fttg1JI04nrGEVMgfrjvUgc0ql6JLTIFqM4bgpB45OvQz9wku8a3ABWS7nwpKzDKKi+4CXndLqlNe3L29f3h7bwmn65vN5nsdSCh7nxfuY6rTPBQOnZb09HuuylkHq+fSP+VxwHSNajn3f97MHRwQROZGksm637baui1BYhdpa+4sXDHx+rgjbFWEgGazZSQRAXjODVSZwA0ZeTOeIpDzKApxVY2vdKQel9X5/PO5v99uKabjMNar1PI+S8ayTjTxAtMH2CHdjJMnb/fFlySWPCP/yrZkz0/o5kPhzGNwBO7ETkbPkZb3d749cciLQ5rW23rsOL5iYPa5ABBg638WvQeBIgiRqFpNQwyVzaAVnRBREhNmJ4JSTMF+Sy95qH4JG1e12u91u2xKJCV17xfM8j5ITnmfr6j5vjOsK+RXht69Zkgwfq8/4fkb4fL2/v7+/D7jUxhl2Ihpzvu9vbyIiFNZsbOkZYf9rhJF
EUlFPOQ1rY0aSOSFzcDmSEFi1npIkTpJozHLkMe6EaJLtWqs6BI0Y67qs27ouPsxPe8N6HrnkjPUybb+GzQ8Mxm2c4fvjqxARz5mN/76lz9fP739+/7PP/iS782j+Syrr7fH4QogEYd7HIK2h9ZlrmWeYmCWbRrpM8YmGJmmsd9znoNaoLItwXpZ8HmcSJhzjP+Y1rL1VFxrK0DlGrvQ0VDoVzlLKkRL0Oh2u/+XS8iEsWNbb21carwN8Co7+ZUu///nPf/5TU0opSyb2ICL8jPDX8AgLjTpNVi+3H/i8pQFZkprBEJIiIE2MIybN5VJu+woCXLb7cozzDnRFeFxarQalMcVw+mymMTjaeoN6lpxSAu3abSQ+v84whBsjD0dknHyGK8J/3dL19f799//1v9uyLqsv6XLhnmN0tvvDVDXUPl+ly91o7um4nIncYZToCEDTD21g9QHWtbt17ZANOG+PbcD7bpguS6cxireDU1q227ZNwii7EIG7KQ4/30ymNnY0fXJIhrL2mtX2CFU18CFx/NzSY1b85CXHxQHUKd4eio4yHrPBX7s+F38QB+0IrXcL5BQ0SSoA7sYkPPlL4B4OQtm3bR2zDq3Xc38+n7QaIDuwpDIo4Jmsenula8H9z2eFfPutp/vj/rjfH6uZqasrEpG3TnicCmkFSfeC/fgpAbOteE6rmtFbEgCIqSxgH11D5fnqAFivR5bJI5ln4lpvzGl4g6nrOlwX5l8O88ABBvqkEQQREiXCbXj9QWirx+v5/kEWxNlhGIsziyF6rUhw+Vra92eLfPsGaVzbt2U2yIfnUnf33hUTyJJuGXX/GQ1mu6u+v87+bwvmT1byiJtdEXZtZyJwHkea5nLxk2Z6acYi3C2QB2d1MDwQFZHQbQzDgZSEJElatzHz1q3XY3++/+RASd2RxEdh3C8vpcn15Hi9GqQ75rSt67Zua7m8TbW1rq3XDgEgsoBsBfse9YkjvYr+/jr+R4RHjH1MjzT9vJKsVyZwlUGrxb+EOIYL7+gNXPfw8MSe8/NmpXoR1LCsBDzG34wRsGNLv/9IyHlRH6kKp6z1jFaP4zwuciv03iFjvucyaEopPDwsvO2g3o79ZBEe/vuCureXzDoZ9Pk8uwMl/pctLcw8pika4ZVGhfVK6NpTSTkH8i+/9Jkl6Cd+ebUN3c0xYoyqh4BwdXU1R+eEkrfbkkseEu1ej9fHx8/EednUgQFYxM1SNDuf7893mngxISKmTJBzLoNlPi/fI5Ss7e+vZVkkLcvCEa4tPObtQnbsRzPk7P8e4U/ADS3mFezawLWdy1IsgOOv6x3DDy6iDzIPApPpKH1suktNpFWNKS/AZXsUyZL+GuGcylbVgZDCU4Rze1l9/vjzz2vBaU5tzDklyUmSTNp7SK/obf/4eQ8BWe93brUOi7iLNtnGMNSx4HFL86WHGq/EJUGEsA5uLaW1ayCl+NzRkxSh7INf4yySAimJAoTHoBDHmHdq3bSbjHlM2yOzEDOhzTP8o5TtbOrIcLWudvL6/P7P/0Vz6ky5P6Dk2+OeedJrppld0PEia/v7d5AV0vr4Rq9ouj+fHyTMLCRjRiPwfJD+ektfCOOspgLCwrURczNHFotPdRBeEb6E+5KBBFkyRvgluhlU8W5dezdZmiGX7ZGRYGRuvR6v5/vPst6Ppg6ESAiECIWsPr//8/+Ncxr6+hvmSLdvv+Wryzz4qhjwkdDb/vFdFsW0Pr4RtFfff/75J8toIdEQQczptAoAF7t+ZArMQ0VoHja4SwBg07CfR3E4bMPnfy9mNMx+4sysyXHMz522ymqzsnSfv+oF+Sr3dh7767kNMjp/smsreQQA+dXQKBn/Qs7DgNBPdfeUYdBwvjX1kdIGX9aGAQDyAoB+nE0tkFMe9pFputk44FSHlTHZBnAMzG0drnQaCTE4nEUwrIeZWSAjMSspIIa7jx1hvR6vjyx5TDlHH82lVBJaff1M2EcemdI5lP7rnYf
dc16/fLkVjn70a8E+eefP9zPS7Rus3759e2yZwoHHlYAzejjJhCPTegKA7rV1D2QYA06XXOvZTgQNHCb7KQuFNXR47vtRa+uQdA6aIyAIDCIksLAeEYHMIp0JB296gtPW6/EqQmU6mXi3AJZcGK2+BPQoZdhKnN0AU1lvqSx5yUtZb/ctk7eDrxM1m052vKqnW+S3t8fb25rZASiVrTtcHAtQ1XCblIcnAOhxDnotlXXbtnVb9uN1YBiOcWcpZxECg+jxeh3nWXvHaRYQSDiovcMWFj5pknZpY51tRFjbuWchKClJkiygU1gLaPUFWj/GUJDV2hXhvK5lXdZ1KUvJHH3nIUVC6L8+Ibd0123bti1TBHBa1JBnvkBO4RbW++eCbT+beiBjGbOTtudTKLQhEnHKJS+IGGod7dj342ytkw43LEAcc59gqpNjuLoLxRTXOQ9Wz7ijEoGXnHNJluja0oFWsR/Pcr/d7rdmYc2CJK/3clu3bb2tCycW8ubXnRl1vD7NETHlG9Iw+iZ3pFQcZXQTIMDBDFz70B6+AMD24b7IXLb729vbl1sRDKuMI9EryzKZSa7HcdRaWye98m1iSiQcXcNNuyUBQpYUOBesV6i1VSFwW0opvbhT9wCSXAytKjPL4+2tNguMbkCpbG293W737X5bBnzW+xSbodfzOOtxVixLyaWUMk3yw4FSQU7l8qRW1wFo+meEPy+tst7evn77+kgU2nYZ0GMuy2rWzbpqP8+znrV1njXiuD05i2NYeB+KJ+RU4MKY+ILjrY/m1bIuq1mE6HSMa2Ea4QBfjrNZoGCzQCmrbo/7/XF/u5eu3bTrVavgJMTuR3qDJLfH435RvRw4gaSl20i3vevM2/8S4bPpmOIzZid9odB2ZkagafbQmoEN68zaam1dBiw9jIVTKsnCOlivFVACORWAGKMH+EKsrROE1nOrm5oD5OvSimuix15HjSlNgaSoD7Xr21s+6nF69NMnmdr24bbyWiDd0+3b3367TrQDJcpDx2BqZtQq/WuE/aiDiTci/Ns/voLV41n+EmGMHtrO/Wyt9VZbF9Xr0mLJuWSdBroncf63BV9lmDZwa+dR25ggGBczzdXqse/nWbs5clpys6BUHB9fvnz5+uXLl/x8YnRvxxzrgvb8+Ph4fjyft3z3tH37z/88zuM40NyQURAQ5uhRhSr06wwjwHQLVPPMoeczUfv+rM7Lo/dlWQS9H12BM0jaru9xud1v9/vttqWchBB82ohTWrKAVTToTbsDJwcUKctmxESC3tXmKORL04WU8rLeav329b5ljg5DCp5Dhpm95uM8e3AJQwgMBMs5p8HMuPLii2RMnyAIi5haCiCWUtbRsgUAF0nZLSBxtJ2jP5/P03h5o544CVp18+Aii83htF3Lto5pPDQs5cyApADLwiygoCeYudkgkOTSe7MB18bnerdLyUVpWXtrvT/uj9vCYKAeQBLAoSdae4mqqlNhn4Wo9jomjKf0Cx8apGoeqhRCZBc3z8gpl227+7VglGTZHYDI2x79WGqtxgsVRUBAs4bT9viCTSwtS1nLsuSR/YZ6UAJKRSEAVGGmm8hCkq/jNAunMapwWS/FHpf555d1WRcBdXVAEkABrV73xGMIkZTZcITel95azmVOlSEcaijtPRgREIc1hEcoSS7r7b7/WjBLsggA5OjQj5TCI2hdwqybmarJpPWwT8BGZvMuOUTAeJ+EkofZsO7QmDI+5qn7ttZqa65qY4DQsl6uvclhQKljBDwo2mQ0Qag1IOApoEyTpOzUay6ttTYoDZ9YTW86h6LTHAgMJnlZb+dxBAAIAQDKdNR09NYREMcQY0lez1rdWi0oUpZlSbOOiqsUYTNXD7cYWgWAepyq9Tx9TL7mNDnfYMcuFD3UPK4Ix1DgAl2OWkgIoOARQEhhPkYx0LKti3BZYWpkuS+9tVYv2/6Lw69dETkCiWmYuaPnUrfaao0rwsFpNNh17jvfto3yst1
8f5I1q6+Q4Lzdb8uAbeboKiLCrh3V3GDGE1+hYOfracuyICXOsw4newqF0ucZLus1zs1Jhkb0KihnzxWwmrZWW8NHD1q43GHyranVXGoubY7R+dRcd6VZXDGOCfKwDMJBg1+XVkQAEtcardZ69i9fsPDy9hV+glew+qKBVXy54VWzwkgBolYER1ckkpRSZtQTtT5/WneghJQuMZQlDK0cpteW/iy7r8FdQy1satOiAq3Zeez7Di24OJUbDHq2YVt6K63mPNRd+GvBzBYAxHxVSwMPnAOXCYaZIxAxQ/e+v17PQ7E8aH37G6K1hFb3PLiAvz2uGTgwJ3gZE3jHMAKSUpaS9BTQ+vqphpQcePITUjZyrTuB2nVpzfIZoUwA5/V6vdCqKrMgCXMLrc+P93dwLvfgcgcbXkDYai6llfY5whsuDbLIVT6LJJZE00DfP4H4UbIDgvIlTR+1yhdo514yE4xO2bJu0zqC43OKjU4qB5FISiVd8w+HoTvLJz9Bx4Nt1ySaXz3X+dOX4tobI4QjGgACaK/n8Xq+w3Z2A04FO46/fnETchKZhgmTffDZQZ0fJvjVyJ7qUoQwCAcujmm5tX/8/W9fH1sRyGW712bx9liLgLWThY2F/aK5qjqgJAfMicC1WlMfzjdjBmDJAtasMev3nx+vvXbt5/6xZolzGPwcJw3odSn7fpzdgR3cOyLg8+N1VPXRpkKkUXO12s5jirpzGl2YoaYvq1qZc/hc3fqcyTRB9V8K8eFMEVwwLfemv32bC07LemvqeHtshdGqiIqJs8T1IA8/CCAQIXAFHbOEl5tPNUlGcO+B0Yfao1mv+0dmtH3MCzzrXHBZ6llrNyAeEKDbfuxH07iMDBDDtbfzOM/z7BYkOU/9yTSI6OZp/LnwwI4ASDLlE3+VxHtHCAchCfOIL29fPiPczJGX25YZrLGImbv8ct1UBxJACSICV8OmgZyXm9+2bVmWki84r76/fzyPs2s7X0nQ23OOJ5kLLkvRrl0NeNCptfez1rNp4LVeskElO47WmgFKHsNGws2BJBWLMUqIaECvDk455ZwJ8FMw7W7KhGHBc7TF43Z/3LZFIi1dR0VTioA1GK4GHm6u4430IEaeuL2BNwuUtBrctnVdSs7q1mtr7Xw+n6+9Nuvnzmj9/KhTm8eD3FqWUXQju3o/z3aevWtTG7R0IprKseN47apmQQKj4RquPmfe+gRgJ3fdFddigRz0a8GmnRHCYbp9p21Zt2UtyfMylCfDvcSaWxr33fxpZhCAMmZeDeb0iLAFzDNcArydx3Hs+77vR+3aTwbv9bV8KrfKsiylLGWCMsRg/dz3fR9K6hidZqSJm5z7cx9Vh3BJM8IBKBmAfMoZYPBjtEPfAjnFX7a0zpGvxGXwE3POJZUsntWBJS8jDTLXT2rHSI7N5rQTcu1qrm2e4UC6JH8K1o7n87kf53meZ7N+omvdl9wvmHYpY77EGPVOHIPV9/ERs5k1dWE4B/jurxfiMIi4zrA6ECckiUvO59prra2CBXKyv27plisjhiOX2/3+uN9kTpuwtCBJWrbZmTQwd4cAcHM1NXOhIUm3Fg7eW+vqJIFC27YtSym5ovXj+f7j1WprtXeNiF5TTknnlyaDDLaUwfkUDrB+7h8/fuLw8aVPqcPw1TxeT2YWIpZZnY4zjCzZYxBF3bWdx3keEMipWPxasJVTiCCMuWxvX75+e8NAjJlDSl5a76316lbHzMEA8ImSOggQZ8kKpmBaq0UgIwdPSXVm8Ha+fn7/GFdSN3cd5+yC8ueCy7Jsi6MAEXivr48ffw52dLraREPfW8/9+Uw5IZLknAZ73xwICDzCpouBaavH63gFUird5xl+AYAdA0La8ghJkhicBp9i766qHsgRkYSZICbXE8kJrPZGbK21qo6MjkaGTjgxnedrP8fUd0CGQaJHhLAIZE4AIDnlMjD4nEtZVt1q6xoO06k53GyMAb1mMFy8nOv/TMU8+PRcC3c
AYslukRisHYkYAORPAIjjHBO08jZa825jlaPQM1WLGBbPyCNzjQACYgf3mJNHzd0cGM0NHRDQFV3bmT7en3vtNqpyDvdZGADRVUHOSTBpnR9XR05luXmAB0TM+evHq47xD/hJy5qfuOrGXzNagSQ7isSWydsOneAaXaK9O2WUfFtLFgrvtY0DdzXEhn/JpAEwQcB4x8F7095ab0PqS5zUlQwgwNW1MfFrZEuBhMQzBQyAAJQ06aNMJEQs6+zvhwNLXm+POcm7a2+1nse+t7N1c8AL9La5Fz8/dlkBeZAEcC6REnkDPfBaMESEU5JS5oQN1zZcT/ocW+VJEginxEOtDgGXUD+anvt+7DC6UpLIcE7XnDn68dqPEeHxxzAtNx3TspRlKYvAUIbTFeFhTnt/vIaO0j4jvLcx/QQvb6/PAMenFc7ndidBSm6BSN70RPhc8DVKOi+lJBlY6uv12l/1OhxlCUpSylQlgl9Ks2jQj4+P9w9c12UFkUJjfDKF6/hl6nmetTtcthw0b3yEVMbs2xn3wCvCSFLW4zzP5/P5Qd7ctNfzOPeXfkYYEMDJ7drH4T4nKcDEFQgpxdgGrm72uWDMuaCknEtO+Yrw6+Pj46PGZHBsQQk4r3lYcg14lpkldtDz+ePP73R/3DEF5092wLjyrA/pmE1DMSYe2BYqpGW7Px6Px/D+cIsrwiRlba219vN7Am/0K8L6bxHGv0Z4PBwwmI5jkiciRqtV24V4fAcAvK0olNetDAOlMa7s4+fPn+do+2I4pgU4b0VV1WIWnJIkJPR8fv/9/+BvBmkJLpMcQOg6nRLVTAfuOnqvpr0rQgwF6dcvX+QarnZFmMtkFm+JrO/k1xle7N/OMPw6w1PJ7ABj4imOUlbiFd3b/nr5dUuzgyyU1nuZJE3Xdu4fP7//eV7OgZQWQ87r2npDcIwAYskpB0M/nz9+/9/FQFYFzgEe5mzo2o7jOI6L+TcVC5K0ExG4Ryrb/cu3337jaZSoV4STzzs3g9U9f0Z4eXn711v6c7DOiHDXMbQNAJFJUk4pR/TD2/7+067OA6diQWm55QHcuPd27q+P95/HBbWMwVA0iS9uhrPnEIzez/3jZ1q22ixIzMXZaI782Y/XLMonBJRS1jYgqEh5WW/3ty/DC4u5L5O2IhPkjfp6DjWiaW/tPEr0SYS/fCsHJ+zTCsndpy8GyPgNiwsNMwmF6afF1wBLmm9G1Np0eNClYS55K+T1RV769CXL7gjMDiRluT2+vOTtvhahsItKPaT3xQ2mPR/NUSJ5KLgAnIfLQ5bee1cPdK0QvRW+CKG///7jefYgGeWPtujDy9wYhylrWcpItoiGntciIBw8WlivIsnf359HszFyQDIAcJ5vIU5PTDtbH98jy5gxvRb29vImQ2Th2iOIRJxI8np7fD3k7b5mwZh98oGSpWx+mSYRpzmPsw97LTDBsHbuaVxjHuAKrvUUuoaK//jjx/PowYkJwq23GJaDajgmOUkZqlNEJAYkuajc7mEjHfXX63k0B8ZfEU6jQ0Merm5qtTa1AR8PXWIS9ubtRfNxtwWROZkjcV62x1np7bYWofjsGw+kahBeR77y1wUjRqhgaDvlcnwPdHXtyARXwvP+8/0zwuCm3adLpjMgyyCyZBFCQEZiHwNSzWE+QxBo53nUZnPBeW7p2aCJcfDrrwiXZVu3FQCi1nFWIALCiTipOaPk5Xa2Tvf7WgTjczLuAA7HeomI+HPBbbJSiDGsnTRIzIAAoQAQOKYAqQ6r5qMHwxVhH9OH1KegYim55MSDIkgRHsYdYoA+I2my3ps2G7ylseBLHDQho97O9hnhvN62232OSOo20mAEZ06tmwdJXremDtu2ZsG45AFD5uKX5IT/suBE45tDodB+hAkR8hgzPpKSGUbt9ay19iD8dYZHlh9X96LklJJc3oAAYQ3cxtiYUfsMeMxxnOG5pWVG2N36dM5U/8W1fjvO3duxH/2iuYNILqoeJHm
5qQOUpRQhv85wABJHjEGjg/ZxXVp1uOs4MLq18J6mSwm5zqW2PnysBurujAOG7BxjhtOopCSXtSSRK8KICGgYpqMhPPzKFIfHiMBft/QgiF8T5WeEATnldXu8feGo3l7v7+fgrongEAs6kJRVHQkkpXRt6esMz+7h8MX4XDDDmEsXjGFhnVPKkhMSuI6RE5Nn0OpoBgMTDzvyfs1Xh0kAWoswCzEijkkTaGHKGG7tHCCwCkuazsx/ubREmBzCVVs9au2qHjHO8P3tKzSO+vr5x54kS5IslMvS1TxQ8mIgyYnHkC0b2W0ABQCy+Wx5yF8XHO5mjhgKiJiyFkAOjN7O49iP86xDYHrp7K5naZhHqml4IEsu2zX5ec6GIjbXPvHr4/V6vV6aS1kKyjjDGwDwupSBdsLl8DoY3qPdONyFbJhjm5h06TKc2EmsVQ1KAWMS0XQsDWTiGIVpkBALsTAJgXe32jVQiuOUuwQMoJyHWlaSKsGYqyhAwgNoFgLXiFEeIE4DpsMvz18arTOy8zhb9xgHVHtTHg44K8GcqMXrwOpjSjgvEHjYbLd6HuU4x0BHjDlQhZgIwuxm3dRRiqmadTVDYiIUGvbYEcHTfAADvAWMIXFOCfNVyMlobCfJjsiSy/h1TElSHs/i+HV0DAIYbtFaGb3nQZklnrLeAVbqmP8gzMQuuSzrdts+F0zLMhjqg0Y3vDORLt+ueuZ8nK2rj5/nSB3H5GvVc5D9hXuo1Vprm3dx4kmTj0shA11toEUQAJhzXIUd5zKaRA5EkspShDC0N+KUl3XZlnxJqwMQgACIQytGPz+bk7NT79PVFT/t1plTWbbb/cYAIA8AwFwm2Akwt9YvqpG2Vo901DlmNcBGm2u8i/24bKxJwdrx2o91XUFIlgTz81kGg3s76nnCOJVCV53DeUpngThl7bowhPYmJFKWddu2PA1MdFimDHMN9H5crgiSPp1hp3VQ+qS9j7FC98fnZEtMYzBCTNkCycREx2tWE/NxjVmd80HCwd20tSPnkjNLxorW9o+P192AC6alXByyaS+Lps3a/nq+eCnLQnlJE6JwSllySklQRoJVBhorPAaY3G+5tdY8tCHzGMJGYd5PIsk5p5xTHtBqa5THpFtJU5P/Oej914LneLC5py+u+JCxaa9CdJytmwWiu5u5mw0/xHPbVgNBKbCD1eP5412Bi1Na1skduhRx2E70dnz8+Jnud0qYt2WOUzca3A4RjkHsWWLoh2Zwbo98juFCHQEJiYXD1SMceHrx5mGNd1babtsGMsVYxBwy5yj9WjBenkAjD/wF84ebNmHEzwiDz3G7w9TkGLNcUZZgsLZ//PwOUnpgWrZrABNc6tLJ1f9jcUwrpfU2WVCGgzCS5Lp4ivV67vkzwo9M4EphjRAZkBOrD/iYBy5WyrGf+3HsB7+9deCC6XOEwLWl5TrDl15yKiLHGR6oUZj2igCf890jbJj9m/Z2HstuIMVAigtYO54//uCydae0bnTNoBqW2kGTq//PDdNmmLeHzuGX014qTVgUaen13Jc0I7zdHhm8NwxtTBxInNhDa61n42VdlnXt/fV6vV7P10tOHb/SpKwRzC09IjzH8cZF6o8p4B+padh4xP38RFbCrdVWu/ZWz1JWkmVzlGIM3s/X+4/1XtVRyjqvMw6boLGA1ePjxx+1bM0prfcxXrozCHMaeMz4tuXYX2spadIObvd8zaJ0mXTHFtbOfT9paa2bqn48nx8fHx/P5JTW5iSfFULKZdm2210AQAwAoH92GEai2j+ez9d+1qpTepC1a3Ch3M6TKcYgt2mmNVnTJLks2/1RH7cpw4qh7AudKrP9jx/PszmlfLkZGFhvZ22QU/IEDDbNm88fe4d8+xr3x9t9LVe1ERGuHQHcpRtlTKtSLjmXnGjCNlv+7W/f3u7bkm21AE7Zv325byXRkAA4APhMX2sfPAa15+u578dZ1a1rSylFQDCl0AFq6gUKm9q0cWBJed1
ub+1zwY4OEAPVb6321+8/Po7uONH3JNLDez2OA5dcAsjBfBAp6vPVIG0u2+1+W4vgZ8YXSgCuJuGYUwTgOAxCUoBS2e7p27evb/et5MGCyMXf3h5rGa5VI8LejmPf92PvMTf3ax/DtNVUOovM9J/Zx7C9dgl+/lLtc1qW7X70+21dEhPGtMP3egxv49ePn8+jO32+nGPE7/F6oa6BJAHT9623Vjvkm2xlXdd1JEUzcbMertrls4rBKS1KQLKs2tOXL1++3NclhyNJWVa/3/4twtbO5/Pj+fFs0+Mv9mM/juOsqjNXybkQl5xhtO94vlk+ke+xpfO63U593LaShQDCg8IJ2/naX6/X/nw+n0dzHEmGJPkEhNEDiFOADfrGoeEBSTYfyjuhS94YoeHWO8uCKZelrDiDFELJ3N3y43F/3LclA7CkZa2xLtv6uWADAGvH6+fPnz9/DiiCAI/zOM/jrDbdzmndOHHetmHHlHh6HvrU8gAgSVqW7d78cVtH+yJo0BmHZPbj43Ucx9mDBnaZkowxssfrAwFIkjlaP1778/UyEZHEaRQTwuNwjCahGxKRgFDe7vcHXm0/QQQAhLzdttttWzJRSqX3HuNLw/jc0tbP1/v373/+2XhWHeewuZjThgDRKC+UtwfPya9Dw+KGF0I5InxrCo/buoxN6AAIiO18vf/8+ePns7fePyOcZM40Pl7vhCS52Oj7f7y/v/u6bpLWbZ3NA7LpCRvuA6CV5JhuX75+w1prrQTXrBYqy7os67IkZi9maiGcWEatcG3p4/n+/fd//t5kGtuMkUmt2vRQBchbcNnexPp57onJEcIdbfrTADKnZdm6weO2lsyEU6+G0Y7Xx4/vf37/GEIR+nVpIYwtPbI/i+HL/P37d/gSac23L29wofgxQ2yzcSSrUd7e/vYPPPb9YHCVoelKZdoE50jgHuEwx2YH/HVLv3//4//4Xy1JFsmS2nSntTmPI2TrwXl75F6PvSS+Gh0wtzQiSsqrqtNtm1v6gsjr8fr4+efvv38wIhHhle+n6aL/+qCUy9otwPrxev/+x+8YvEG6ff2bT0b69SrF1RKVN8O8ffnbf+LzI3FYp1SWMXRujvYQALpsHuDTg0M6AOjoNT8/WkmWMgAOtUqAXdr9iww1RQrb3SbQMJxBtFWcMhEb0j/tU+wHMDplg4/NJETl8s8fk3aOQ0YtBgiurZ77C+99TN21huHq2sdYW8ngboaOzGlUuWC9JRk2GswppUQYBm4XS4JwQN42uJYNAHrrfQ5sZ065lMutfuriA+LxtrId7yAfrw7p5uwTNUtrptDz9Y77PtiA4dYZIRLQ3Euc13t3GAeSEIfnr8Dr9+8/P15HU5z3wNTpfg7WiTDrfVg/W1BaFPpFC/nyuC1Z8NPt2a0PnLlddEsekgghm897AICcAKCtT7emoeBcyqUcvZyhYFkXthMbn2eHdOPF59nmNXH0+hI865B7gVvHcEsTeEHgvHYDvgGMASSFoZ9PsPzn9x8fr7MqDSEkTn/4+Eyr3FXbGMenwWkJ1lnmy5fHtmSmXz1x0wYQrnyxP8ZRLpkHa9D/NcJqHoHXGMzr6A4EZDpo29lfaG6QeZlvXwSOCAtiG+NeMdz68MYUEmYglLI6SjkvEUuW0AO08vuPEWH+v4twa+dxns2A0kp5DIuJ4C+P25oFf9E8rCOAa6PL/WFwztSSXdSAXxFuauY+sPOcl/W6nMeYDiRyd7PuBkwkqfAnfRiXzKEnOqqamgVCaIRpE0likkCQ84qS12pzirEw6Gn1hc+Pj4/X2VRsPm10FS0jwDHm6p3nfroGJRIb2EMEf7lPt6xrMxiCu3bBi+623m7bTX10u/49wqpus4Wby7Jezy9dA6bqeR79PE9fliK5rCVMTU01UqLo6H0YwXhgOIQREWeTHIAEkkHy2nqf5SBgdKuIvr9e++uoeh2eudwx8Wys2Hqr9dgrQBClCyYLoOsMXxE2hDBj5pgz6dr
9/nh0u0iR14LnGf6cwTaFaDPBGsApi3x8YLPj490eD1rT7e0Rc/adEVOot+MyA0TwMEAAVssRyIycuZiZXnnyZGmansd5HrX17OGXywzitaP9stc6j73xwCUYLj7DY57hX/QdN1IivMYC1MeX1g1Iwv/PImzmPtKHXJb1CqwM8U5KBfvTzvc/utEC6fb1N5hEWHUYYMvsIRGOLRcu5g7IYsg88NNaz3a2M9U6hkdckz1+RfhXh+jXils9j6PmTJxTyogBGAC8bbclj+F1k4o2Lmew/TiP4ziOL3VMfoO/XsFSAaDPKXtB1wghmUYlw3wppUz9xXa+/95oeUDevv4Dro5It26qXecor4l6hRk7ACJPcJiZfRRNcoLVfr5er3MAPH2+f7+Ah/i8pW2QO/YWJJSWZaXR9kdaBrb8K8JT5h/99dpf+/7aT7VALgv9y4ITAOCyDXUpv40R6fdhnDgqMIKw7pjK7e3U9tvbrTBo8zG/52we7jh91YVnj1lVO2k7l3Is5RP5OOtRj3qe57EfZxtUREQWW0ticL1IZyxwOWw4kKS8qOTEYFqDAAIhkPo1qGBIY4dXmbu7qgNJWiAzWN3f6Qy/LBWvZtraLYCk8ON2f9zv99snhxWndkBB1nt3bl+/3hcBPWxOlm+ESMhyNTto2Be31rCVVHIulyk0R71OV621Ng1AppAI30oidL1IZ5xg+k5pjOGRkVgYrGmf8Udc1lpqa832o3YHYoyIQeQJlBwkWyJvO3nBQMCJ20kGAF81gFJZeVgZ3bfZK5ijNRRQQZa7U9b77b5I9KOf+3kcx9E/RXpDS4KAYdrqWT2fw1Vlij8ZahtDTMcIE/XBzUOAdUkMofWoY7al4F95/dkCEyIMAGg0voJqXUtrtfl5NjUgMR8WGkP6jVJKRmu7tzzG0E2RR4HZPUll/f+0dyVJruQ2FCDBSZL/7/uf0L1zdymTI0AvAKoqfATb3NRGESWITJKJNz29SkYfxcI4HBxrqAWUxcUHp5yyh1mH5WfMlJNDitlb+1nRilYr2z7gT7vkeP4NbUYvOVFJWFLwnxkGdAE0Pph1SW90kc9Wa7kAWHrPo48BY0zzLVIATjOZSDaRkyHDsBHnv0ktIqqVbj6XXHIpWclZYOCnto6zi4/J5H3we+523/d9X/d8bBeA8oPO5Qi3zNHrtbwC/+67YAunOeAFHlREn+HZveI5ntAczJego40+jDmH8BjzcGaxF93m9cKzHQlqHi4jIFmoPMtAwECq2PoUrNrwMYY3YnoC3fyR3RAQHlOtzuSoI+es13Vf13XN7UNGig8SJQCJzfA1j27ZWq8eDzH2W02l5x6lrDOMlhIbbEkzb3CEnmOsyJNHa4dTjqWXPsYYXgREwNFyILLmYDsUvIn2eKsxRwwfQHx7Smvyml63mRjPfWaBqCMEOnLonYNh7ItLM5qu4UIWpPwkvTbiQuA1W7072DhL6mNFowoG753p1kJMwYGssftQ2oEuaREGcOi37Al7AM96LRFmYYFezBTzcydUKlzfoDaKYXSR0XrnlNQ++0NqAYqbRU2AI4UYyE59GJsnyBqdIsWQQoJa77bHbPd1vb+ur2vGMgUpPmjxXDyVgzN6vZpec9XYw3vn0QBA9oEChY2AjhQXouBRFizbpQlJSWPrRD1OmR54tPdizaL4bAjhoJfOnuGNhBRyyvceMu7rWioDY022VK7vBnTgvOZbeKch9GJGsXPOgYIuxJxxLw+y+n3XWlvr3TKDYwxu7i24raeHB3djPuewmrsw80Z0Tt9GQ0wl5+Ccc8B7qTWnJ/yYrcPHoE3WHK1NfVFnOBTL6DUXcZ/OC1JIOZdScKv72ThomQMA+hMAwCjG2xutKCzhJYv1QB1jynJuC88B1/u6rut91742UmQpKYVziMw5+lhA6QWUuj6wa5KRaFEsIVDpD96hCzGV56Oo2STI1hs9Osu+DcEs9mX8/fWura+PZ9WHRctbZDnv3BiMlCC
gGqWX4r2jEGLuiqeSwx8F6/BH535QrqO2YzdReI2+1cT2utXwK8Iu2ZhwmvLVxwLK4PNznE7ggeg/+Dc4s4b0FHJ5vJ7+pP8s2Z88BwohRN4811qzv6/rbmN+F2w9cWFRq0HkxUgY2Gc9a7InCiGm3NUAT9u49E8AOK1BoRRTiilGW8iLwTpTa+81yZPc91Wv+77HWht9dKipGWazOUfvCwh8Xjz0KtaMNEzBmQJDjHSutjSP5z9ebs01N8s0woqRWygGWLJ67+aLbLFcIiJwxA5s33AjIFIApBPTEUIIMeXS8bu18D3DsrdsssCHdBJf5GzxIILOOcf1vut933VtESRHviib1aySR29mw4bzvu47Bo/HuN/vI3bQOwSjmRv/htH7ZuB1XlK996pYEuTZ6n1rs+/nDGu+PAscTYf17EMwbW7WenPp59/+KNhUMOezqffeRut9R9V3o8DCjYBc71rrXavCqgRBUc0fBVt+Bq33+52Cx63m57r3b9h7sxJiLT/s+foF1YFMkKmvNM4mOIQwQWa7v97XMLTv55LWbctkvZLVmCGdWJIcQ0yplNrWj5CzU7AutmCPe2qt1V5bg5JLKYweTOC2mr5tNrQ4h5jzt6J1jd5bIZ9yzpn/zmrOfiTidOQ7KhLavD2FVB7/+LUdyHSfGf4EWITYVaL21zXHGnOs/5hhFoFlZERBQsrPZ9GiU44x5VwfJ1ET1EDsT+0c6405lPJ4lFLKmUd8PV9TgKKwyBKW2RR1aj7o+Z5LTuEIPNccvYVE+fl6vSRHDbtUQCxFo/Vs7K16BFYleXm+fgnwJLdthtWhlQKFGJRv//7rSztK/LNgm2G7r04XEob8/P2w22JKKbfcWmut1uZApvwoWIdlQTzKfdXrvurlf/2aAj4JG1Y+emutt9qDOEIf06OkFMl7Yx6P3gtQev3x+48dg4fNq0cLRgmKCm3sl3dbpttnSbPMQR5kaavHwWeGPfDs99e/vnT1/njdt0OJ99JQrRHyEym9fj+zeVL33HrpvdcrkQeZukvPnwXjnHOGtcyncAw/zRhN+yif31cMScBD2fv0ZcR2HQr75BUYW9r7o1NYypk5Zufew7EYM/tm/B5Gy1z75Czv03bWP6cPb00TlQWcQUzMbKkJ+h0d/I+N/xf83z7+DRGKQbk6wMnLAAAAAElFTkSuQmCC " /> </div> </div> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>Even if we try to be smarter and attack <strong>both</strong> models at the same time, we can't succeed at a consistent rate.
Be warned, it will succeed sometimes, just not consistently.</p> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="k">def</span> <span class="nf">loss</span><span class="p">(</span><span class="n">outputs</span><span class="p">):</span> <span class="n">entropy1</span> <span class="o">=</span> <span class="n">Categorical</span><span class="p">(</span><span class="n">logits</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">entropy</span><span class="p">()</span> <span class="n">entropy2</span> <span class="o">=</span> <span class="n">Categorical</span><span class="p">(</span><span class="n">logits</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">entropy</span><span class="p">()</span> <span class="n">kl_loss1</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">kl_div</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">exp</span><span class="p">(),</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;batchmean&#39;</span><span class="p">)</span> <span class="n">kl_loss2</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">kl_div</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">outputs</span><span class="p">[</span><span 
class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">exp</span><span class="p">(),</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;batchmean&#39;</span><span class="p">)</span> <span class="n">distance</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">mse_loss</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">entropy1</span> <span class="o">+</span> <span class="n">entropy2</span> <span class="o">+</span> <span class="n">kl_loss1</span> <span class="o">+</span> <span class="n">kl_loss2</span> <span class="o">+</span> <span class="n">distance</span> <span class="k">return</span> <span class="n">loss</span> <span class="n">_</span> <span class="o">=</span> <span class="n">attack</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span> </pre></div> </div> </div> </div> <div class="output_wrapper"> <div class="output"> <div class="output_area"> <div class="output_subarea output_stream output_stdout output_text"> <pre>MNIST Model says : This is a 1 with probability 11.50% MNIST Model 2 says : This is a 7 with probability 11.48% KL-divergence is 0.474844753742218 0.47643253207206726 ATTACK FAILED </pre> </div> </div> <div class="output_area"> <div class="output_png output_subarea "> <img 
src="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAPAAAADwCAAAAAAbPrZOAABnvElEQVR4nOz925IkOZIlCPINgIiomUdkddcu0f7zfufU9nRlRriZqggAvuwDQ9QjquYLdladKIkyMsPcWBlg8OWcw/j/BgCUIlJEBD3cw8NxfSAiAiLCPdzdHWuptZZav//Iz/M4jv04jr0yMxMzEyIhEnpfn7Y+/s//zD8f//4///1//vv//PffiQiJCPvX9/fX99f38/fffv/9t99//03/4z/+4z/+4//zH//x+O3Hj99+/Pjt0Wqtrdb2z//4P/7jP/7j//iPfz2O/FOZ8m8mIiYisuvq13VdXd3MzA2PYz/2Yz8IAAj+b/b5/xv8/+sfCQDIS+ruf73DkHcYIjCvcbi7g5EpIaKqmXvE+jn4lw8ABER4BAAiEa1/dv9XJkIEiHCMgFj/OhExEyKE23/54fkrupspq3kE5E/Nf7p+UwSA9U/uvwkAAB1RmBAiHABAFAAwwtiYGCM8IsIB8O92uJmpqYVkeCuvr6+v5zXMAwCJiVlW0MIIh4CI6UFckUspzBhugVzq/pjH3oTCRmdmYgBCkrp5kDwa+Xih6h9fz2sY5FcAETZDZ2eWP38+u4G0o1VGt4HGRMxM74/PqY5csayQi6VKIR+Bb4PdM8JhREQEOESkc3nFQFfVOeeM/O/M1/P1fJ1DIwKRiEVY8n9xX+fFPZCRTYSFIMwDuWz7tH1vQmCziwQBEiCX6gFcWiWfp/fxx9fzmh5ESIgB7moZ/L//fHVD2bQWxtARSsTERIyUQT/MLZCB4T6nSEDgY/7yMBoiESIGREAEeHiER2ARkSKAbnP2MbvlO4I0ruu6rmHr+DALL4sNPMxMIwBQICAP8fJw2y3aXguFje4eghSALA2QS2PiGNZp/PH1uoYBEeVtcEs3xOvr1Q3KFsyMoaDM6znK34wQAyCACYiZiJnAXd3c7G0wIAJiHuMACABbLqJaaw0gNp29934p5P8P55hjjqkOAEjEwrJsBge3qTNw/Q55w8LWkQ6st4c9AJA9kAoASRsR7hrh/evreU2D9/13U1czs+s8u4E0RAT0cKD7/U2fERIhExKRiKS7xriGWx/xNvht6P0xM1c347apA7K7zdHP65wRDhERtj4Q6WCW+1S7YZiO4SLIIsIZ8SIskMsWyLLlHaYARDIBZEAqZjrHHHOOeT1fr2tYvOObzZn/2xxzGMom4R5h4fj2bn49RCJFhKWUWkuttcLrhTF9vPxtcP5CHivSIaiaqZrxbg7EFq46rvN8dXdzd/MIgAgICgBCImERYRFmQ3DX0QPycpYwNw9zC+ISwCJVhMImAhKyexATi0f4Cy3G+TrPq1+3h/NIa+9Xv65u7h4g5Kb5S8J9d9dlRW4VGbi0rbW2tW2Ln+R6+njZ22B3t3BzQAQERJiqOlWV1QFZapjN0a/X8zJVU1ODO0qsoMVyG8wEYTq7Iwdy3Zqpoto60sSlkhAj2AhEIvEAJAQAwEDrPl9fP8/0ZNAdWXz28/V6vV6AjExMMDAsdHTIU0zvB6U4cEGu256fIzjmSfZXg9/nM99RxKFzzjmnBCCXZu46Rz/P50vnnDrnpJLxjFfQygvDwswIYTp6SAmS0o45EXwFLWBxQwTAMHckYfMIXC8KWn/GeP3852luZg5ESAgQjtrP5/f31zeVJo2kMIFh6LgcCRmRMiIDYgUugVzacRyPx+P4iJhnQR/nBAAxAAgPQAIkvB/egHBDRHA3naOXq/fexxjTPJAEiDifc3NTVZ2TAgAzR3Ez1RmqZuaxUgadMyIACQkwAH+FjID0EjPmKRFhdCRy4sKZidgYo/feOwUCkzsjSXEgCkBEAALMP8h0/9pF1cwMAIlLbRv/ClrIjABAK15jHaMUmQMLoWsnOF9nH1MdmAv8Jb6F6Zy9X6VYteIOAaqqqmahOuccwnOMMfoYCrHMy9wIW
JjuzAIjHBFI2v5Qq2aq6qpFCMMmYZ9TzSLAXRkx3FCw5OsXECtfA0DkVii0g7mbmU7D13Cqxw/Xt8FMjMzE+YAgQB19XNIJCoNNAj/Pq4+pFisTYV/PhKnO0a8ibuYBAJBfq2qozjnkLwYjIBDgO4FjYSZEhMiMxxGptF0d2lSdU+esQuA6AfqYZh4QYTgDrACKIGYACvfw9V4CiRAo2HBTM53Kz+FcD6d3lEZCllKkEEIafPXSWYhCGFzB59WvPqZ6UKm11lJ1GeE6x+hFODJyI2R8NwtVnTKE5siPEiJGXhxCwpURrjic+SyXph7cxhxjDmEWArfpMcZUc8+rFiEuK8mFu+yFlQ8jIobaQFM1nVPlHE7Nqfp9h1GQy9Zq45VCQylFmImMCH26Uu/9GtMsqGT0G+f5IlAwnbOLMLlFACLR/aSFThFmxjnzyzEiIkAEWkGKhZkIIV0MDggk1YFK66P3Lh2JCMPAvI+ZdYO7hrsWLLJt29ZwqppO0/ctc/cwzwOtOmdRdapYj3fiQYFc277vcv9LaS/iBAB3hRizjzHVg+v28fHx8XF9fVNYD1MdIkJ455j0Fw9Pob8azMwMHAiIRHejIp+kACdwBC6OJO26+lWEswbycLB1hyHCItzIC0j7eDweNOacc+jMug/D55yqU/NpnWNWgKBa0rLbYCnb8fEQgMAACFm3i83N3dzynVIjKvvjx2+//XZWDh2cHmZGzKKNmd53mHUKMcFcZ9rEIxAJABCZSZaHEQIDw5ECSYCl6naeIkyYSb2B232kwTO6lQ1ke/z222/Ue36fkOcE9AIL7f2ay+CW72WRt8GZIBwfnyXDaEAaDIBT3eecOi1jL1LZHr/927/925PBxovDbA5iBMgiVYTfUdqUBP9isLvAekeIMjdb5V9+zeAADFyaea9F1oMEbuaqY2geaXAHBHBD2R6//Y9/4+vqV796D8jcb6KNsP56jnwmxr5tVOq2bQgA4gCAQFza/visGeIj8nsPAAwNG/1uiHlQacfHb//275vN69kYw00nE2VCzW+D1UIVGQjibXAAIjsg3kf6rg0gYj0PyAwQ0O5WwIRwcJ1zqpp7QERgthVR2uPHP/6nnOd1Xud1BQSEQ3QVCh3n9zRVnX1Mx4J1//ggAJADAHhvbX2lYeHu3qcDlQZZbUUEGkA4gOvo57MVPs8JZf9hW2uttta2UgX08l7yrx8zAB3MZtGZFUEgFSpta4QAaoCYAd+dINYfDAQMHOeVZxhISIqbjt4HBfidLZSjVQaf3aY6kJQsZ9S1v66hDiSMYZMggkikVrV4G7xtrQoTwV0ETTVgaZnJhrsHRhBChI3r1YRxvGbIbviotdZaW2UkUBsIo48++hgQYD51yOodTGQBkrYf4WZuZtDq1qqFU3h4eASs/HCeV+9D1YAQIAD0KkLgnrUCI5WjFQab3YY5IBccoTbHHP1cBhO4IpgacSlt/MVg2rdaChNCuNpUVTMHKiSMCL7K7zBEcJ39VZjCrwllp6qlllJKLeBuM3MbHTrnBA81niy6KhEsLUjq/vAxTMcYtm1T1QHQzS3M7S4B7LyuMaf5qm/RqhCAKQOzkDDXfSsErt3UAoiB3MBmv67erzEdqBCGgfMwkVLHnH8zuFURylsz5pwRAFQkBCDcTTXCnRAgbFxFMBTNQfb2iGzil2J9mI7eu5uZmSu6s5LkO6WmypsBl7Z/TDDQfp7jmGoWiGjqampGaSD61fNIM7EIF3YhDJ+zk4gUKdyOVvJIuweQkOgEn/18XdmXIAEIcwNSKW0bU7PjcRtchJkiTOcYY+TsgLBEuLmrupsRQoTOzug2CjJKzWYWC4sM8q7X8/kEBwePQGQkZEJzUzMzmZa5cs8w+n1NNQ8gwnUG9N2I62MZjFJradUZw3R0Jqm11FLasRVGn91WLQAdw+b1+r5MTR1JfPXmRm1/MfgBALi3WosQRrjp7NdgESYRtogw0zndlZeHGUJnb63VU
lurzMRMzJd30PP7j58EiJkzUxYJ7mZubmU6kNT9wfME7c+fr2kegCKgI5MH4mzJxZh6B626bdsWGK6zC3MptbVW29bWkabVATgJbPbX9+URHkCYnQpzaW3r/a8exr21KquQnaNfvQAySyvu7mo6immm+a4Dw+Y4Hw+qsn88jixkmb77E/T6+uOfTEJMWX7l2+huHu4+Lai0/QNPBu2vry+LAKQiMPvsY4whxMwkBPN+zEnqdjz2fB9qYc4voG2lFAabYMIIxMyFwGY/v68s+QjDw+ZU5W3PyueXwVttd9CyOfp1OTFQaS3Mbc42hilnN9sm2OxFFFuU/cfvn0hACIT4EtDr+1//u0iRIlUwG0fhEWER4XEfaa8c1l9ffzoAEdcC4xrXuEaXbI4RqrlmC0Dqdnx8gOvsZxXmUuu2H9vGxAw+3QKIkUsRivQwrUPnFjZHH7wf5y+DKYuMcNMJOmbG6EwTpUadtbbZZkQgIPEUZpYseFhqbRsSECLS2FqrtRTJnMGAMmkNcHfInAHWjOEeUEDc/b3VInBiVmICs3y6ylQzX8kZM4uUtm37fmwAEGYQZTYLIA43szlGHwyISCLrZwMJgmm/ngIAMgAAw23OMoqOa2gglyLMzETB2f1ryLJtYwwVYmJi+vzcm6DPzkTBQCTt+Lw0sGKsooWFmYT5jkglU4FxDQsq2+OCj8/HvtXCkUmaGWKYA0HcBkM+w87Ps6sDl9Jq2/b9aJkwqEnb2thUbaUqq4sqpZBUnWqKPz4aWf9Cehvspiwi4nMOc+QiIszEaXCdm3GpqnOqr8IOj8fRBEMHCwMyorRH1yDZ1yOkRJJD3cztx+BCGKbjGgZctsfk43HsW0uDxcTuG3BnJmaBWSvz8+zTkUut27Zt+771YTZG77L1faqaX9ddQSKxlFJlNddx3xrbBfo2GEzTF2FmFsCwLKZgkVKaWsnfwJ2ydse2b+nhEgEEQNIODSrt0fvVr+7hWNqxH/s+zuu8rgtJGMJnv4YGle1hsu/73qpwMEtW8Ra2bHU1N7t7IZIepl8eBgMbr/PF+5jTPOJvHhYpZdX2gMJMfukFb4M1ZxWclQMySR5pRGYpNVtx+aHV2aylFkHXEYAUgSTtCCrb47fn9/NJMRxRtsfH5+fH9Xw9X0IBsjw8DbhsBrXtrbVaxFmMxczBw6bq1GWwek42rVx9qMPbw0ezDjbOr2/KujHi9nBWqVJqDpyY8p2YZ/wyOLu6iVVAQiQRESYiIJGiZg54t6wBAQOBmUnQ510BkTSnsj1+nH/+WSkmR1DZHj/+8ftvr6/vKgjmt8GmQWUD2kqrpdbCwCzmZu4Ypn2MuXJ6DQh3d6tjjunApaaH99YZbJzff9BUMweAt4cBM/IUkcJFBMboffQx/G3wGq4isbCAMJfsqBMwFzOPWAkB4915zH8rpiOyBSCWoLLNOcZeyedJjli2x2//+B//47tVwfCpwhg+xxUGVICLSikiIgx5h90V3ea4rmFmam4aq8tb1dUc6e3h+mKw8fr+E6e6AwK+7/A60rWV1kprNZ5PmN6/n/Y2eHU6g2stQHDbywTMUszhHgsL5QwvB4RuZkpsHgFEWDY3d63s8/qmCJLt8eMf//7/+FmFwnSMQtmiBwOuXDYnIWZiARZjcfcJYbNfr25mambq4W5uujk4BDDX1tq270f9ErBxfv0J6uGARP3tYSJmqXVr+75t+xZ/4jzt+v5DAUB6PodZnsnmQYJc7qAFLO4OEFkE1sLh4OHhOsac7tOZizkgieTUIijG+VXJAUt7/Pi3f/9/boJgc5yQz1IvCEgkuK4SIQV7hlTCNPj6m8GmOpCACJlqq23b9r02Buvn9x9gAUDMdPX+l6BVSm3H43E8HnvQfJH1r3/Od8cjgIAwQDKG8A1TCDCd/er9HKsI5NX3DldVB5IQQnDTcUOdIoDrdnz+Bh/HVihWM8x8DSTm6HzHCkgkAupq5I7sTZoHEiC7+/buTASGA
6L0NsYYA+8nCFlK2/bjkXNNd6+COZoEREBCmEGyHZ86YWE84v4FpNVaiwjTGm+56ejXeZ6yoFyMOWLM/AlIUBghbJLlrB1Dgcr++NHhx9EEXa+rjzTZVeccLJJoDqIM/A53q36MMVUtABHZASLattUiTNlvCQButfU+BmRtgUSltm1/fPxABIzw0CIUiqFrBBw4DMv2YWhvDyMyExNLLaWUsppNGBE2Rz9fr+caeAtlInnHbCISxnCbyDekxYLL9rg0Po8mGPPqPe3IMYWwOAMgIJNBuLtF2tvHyKp/zRMREepqPuGNSeDWrtauBreHWUrbjuPzc0FpfDJhaOiINfOmYVg2S1sTxQP39G9NBO95ALjr7Nf5/M6eGxMTMnL2Y5mZSIQJ3HB9D4RgQGV7zIDPRxPyeV19ZAvOTHUysRdEZyIOj7CclK2XYww1u8eJTFhqKXU1MN3dnFtrvY+BY6plUlXbfjw+PjOfDR8Q6K4AHhEBEKUbyga8+dvDgiy1liLCQsJviyNMR79ez6976rz650ylSAGikt1jDc9LgYgaXLbpFMexFfR59ZFHetlLFEDEgMzu4KZTRx+9j97HNaZq3hXO1oKIiBABeI6KuG3XdrUGQ839ngvuj48fa65vq5du7qtvW9WxoGwat8EElC+6JB7mDa2CRDq8nl+IkCXSPemvrSEDSUVECHP9u4edBFomn9fdrkkPExEgcwASW3Z5Z++j995HH31OM4+Ms1nD0F2Yqk5V6Ve72hg03x4ubTseH58RHm5uOGeozjHzAcVQCCyykgfJBnj+a7vc7wQTAWYDZPTr9f0duFqoJcvdsjlQAZKS2S7EnaqBBVVAaSFFBH1GVt9p8STMgUYA0mobzt776FcfPaObRwBxqbXVsnAq2ZOfY05urW39ajimucft4ePjM4sOV7giJw+e6VEY8z2pf3sYWWrbDoE1PryfpTCd/Xw9vxZ8C/JSlVoMqTiQFI8E2sKdm2pQIWkaSEjoU/vbw5qTFeJyG5x2LA/3rnOqugcRS23b1gABMLJtonP2wW3rvY9BU9UD0lX78rCbuUboCB3nM9E3AN4allZbw7sBQL9K8gUFhAA0RAr3dSFu4JavGaGZ3+DAuMF7DpBtXWSSVWwo+B2mlSm7dJnsA2KiBWxBBQDvOEDEUmrbtg1hwZcZIVx1lN6vs9bi5zWmrn8J73qhtW2YDUZwnaPeMxAgKdu2EwDIlkErtIMOBoBAACi1lFILBrK07RjqYWlwrVnkbntjsHEaIAAQyRpMe9w/A3Jo6f46rz76mMrIHhF3apdgvYgFoihNNd9iQpRSSq2tYT6m4eDGlJ28qzDB9fP7+bqGms5+Pr//3OT5OocB1zbaHHOO2lpt27bte62FCML8NhgEXcH6eWNrsbXWHJASSjbV8uBGYG35p9TCYAMmLShc6ARTVVtNHIIwNTWz8zyva4w5mWx9IbEunGe9D4ws7mFXvy4hCEh7a0OAQAgwN52EEDa7MKLXr+/n2fs0Hf18frVK19W7AdecoNU2atta2/Z9qyWTBXwbDBRT8QbjAeK2TwskCSRpUz08LHEFLWdJVYgYLJRLKUJUxMHU5hgj+44MsOLqfJ3XNfqYquwevq7BDcd0j0DmvEjnWZggDERqehgDEADMdDBliGOC0PJ8PV/XmGazv55bEZxzqgJVG23UNkdtWVltOxdhBNe3wfdfbnADCx9qgCwGyKVuDnj/dthaa1ttDcEBTAObAyFJM8/OfxeREoIU7jbHmOM8r+sac6pwXlV4o5HzTAckeJDxWRgxbKaHa22rAAedo2SMmwPBbcp5nuc11HRc57MyhbuHB6PPOcaotda2tW3f941J6G8e9qk2dcwZa2aJZoFUihGwVAdiv/tDW2ttq1tLULNqOHABLs00+8JnLTUA08Nz9N5f15WIJ5V1pP+rwdlFk8LCBGFzxPsOrxxWRxchBNdJEaade7+unh4+qxCuzIclM/NRc6i5bfueh
Upkm7YBQIS5Xtd1+WpqkDuy1GaIXAJJSiIO3XFrW9vq1qz3MOvDgBOVoZ3CtZ9PbRFEDhCJZrvO67x6H8OL2gpaeaTtfaRJWq2tFSZ0n6N7erg1XNMLvWphgvAY4DZ7SajMzDssBK6yMJY4Wq9t9PtI71sgBEA+y+lhHTH76/m0hT0kQJayTWNkAeJaff1+uG1t21rbBoeCjXNyaQ4kjRIg+3qaB7IEQJjNcV2vs1/XNcaMxLnDygHfHvZAzm5zQwTT3otLqaXW2tYzBXrWBYGI0MnCOY4yNZ09sVyttUDhSnO2X0d627fdMwZF3AZbJ9f+/PrTFpoekbhuQz2QCkuxLac07rjloGM7QTvYOHupakClkVD47Oe3A7GYr8qjn6+r57MUWereFruRLZORStsfx8cGYXNcV7F1ots91Zut5h12NyKie2JlTkMIQmc/DkcBqjxaG2O00Vptbdv2Td00ZwEAIgCAsq6MASESEG21MIRNzLwjCzagCGR07a7jen6/zmssqIzO/HakbKMVQtdBft3vrzkgi0MRIlwh0swMLZEjGnMwEZJf04Hr9rC9Mvi8QBYyIlvrtW0KCJi1MRNAQNnathWhlcXo9Dwxux/7nv0DMAigzKmEAACkts0swBYylT4/9ipgA9Kvbvl4ASBaWEfifr7O8xzqpjpGv5pNR2mHoRShUDB/ndfV59TMeIFgTd3dzFRMFRKnOsdKKUZ7voZT3cOlMOhlI/sOUN5lkcbiji0YfALlaq1FGMImU0wHLs3xOPIJpohgCMyBOAMASGnTHdDzcSA+9vyKb1yx5wQpk/2IiBjj6gvsMufoVw11lKbBhEgxbWjGqqkWSAwksFVhgkj3KiuoTp1zTAVwneOqfXTDcrAjYUwbeCPygViktu0Y980nKVJESlkTeRHCcB0YakDFkY9jb60wUbal8K8eLnWLQIq7zm+tVgYb/OYCMCESE2QusYBfcyYUb/SrggZKDaruHm6RcMXEKwIJcWBZ9zDtZQZdHsbIYWgxN6fKGSPV3bZE5AcgSaltn2xqpgFA0uqWkT1rSGEMUwhQBy4oZd/fHobVmloexmIeiCzB6yPMwmhzhlu4hUkIEpKA2ryufl0rajuq6ui1njSduFGZU+c0nXN1bVQDEJgBsIgIQtjt4pgLlh1uU0SKJMK7ko+EgcyhDsTFCZlL3aby5AnhEVTqsR/7IasOT9jwdINwJOGytQQnERJAgBi+PYxSHBClQDauJNHWYA7rAFkAERALoI3z+Xw+YzXvOT1chD1QqLifHSy0Xz0BolPv5JpI+C8eJo7bw2Grk1JKKaVI9RNs6HWeMwBZ7PawGvGAcAQgafvHx8eHgGeRh4hhnr1EZgKsW213BxAWlSs97OJAxLWiLGD7nWdkAm3hQMaILAE+Xl9fP/+ENXXGvMPCBQEFEYHRRsz+OhMArEpJpeJsEv3lDsd9h+0e8ez7RrXsu+fPeH7ZQuQzsUhV88SrIQRJ2z9+/PhN1q/qiz8HKCTEwlxrKa3wYqjdBtPqS7MUVVz3n3WOqcPmnQs5sgQQS6CN8/uPf/0Ty+pwJrqU2RftgcEGhfbnqWtwKoTERQqtv9XNTNmU4o7StiCe+BnUqO6fYZ1Br+dPR5a6aQBSDn0Aw5UQEhL44x//EF3j2Tu+UhTgUmspq7NMAACBaG8PZ02PLFiLlFIKd4qwsDljGUySiXCYjn6+nt/UzIEEkuDAhFoLMnKVzAFt5UGmRvl9Vr5Zb4uASLEQa1PDwwMcUMruyDUWfuELS21jmgeQlAgkxAhTJuLStuPjx+8y+5yDMCAwws2IBYilVknYiq+Zc0YtEQBwX5QhGlVKLZXXEEDzSEf4pNDZa42vU7F9OPF6/qpUAes+So4mijyf3ag9oM0xxhxhd/0bFuZhYXW1tPBOQRYOz0H79WJ09z/+/H5efb6HFWhABalstTAjAsysUmslQ3CbvQcEMjmQCNoMu1iIhDjnFrBoAmmwheu4e
sc6StWiMmafY840ODwcwUYXETgvxfqgbcWXKkgI6gOlcClSivTRnRqW/bouRrf1N+XdVVdXzSGuYELY3O3meOq4GMNm/Pnz63UONc94P9iDihT3IolV5q1trdZaUCn7IHccQAL0qQMWZ5xx0bH+8ixhgkZPHLWqVpM5015dDC4Pm0hECGaGlTe9u/Zs7q7ung1cEXH34FaP/iqMYXP1DgPCc9iteSVF8aYP2T1z0ZHzNv/+/n6efbqpqc4xGIEZEXCBjZy2rdXaSsGJ2U1mYRJmwdVqssXf4rthTr88TAslC3XOqaol3wtVu4FHlkci+YSVfiHwcIw5bIwB2cKXbHgU5lkYXQdlrQAQ4Kpjjjk028nG/jeDwWFSwuni+Xq+rq72nr/JYuwugKTifaQjq+R+FSDkUivonD6nDrp7gisJ5bfBRuA6r9czakvimS4egd24Ik+Ih60Wz9bynWGGFzlof73eRNvWNixta8oYOoTutl3AghV1Q2IpU/mXh/Oag2Lo7KX4dZ3nNaZbxrVRkKjUWmrJ1q7ifaQ9Z1v9Wjj3DbqrzatfNzQXb6IvvquldaS/YapOU6s32DNRfBH+JqR9fPBePz4+ZI1eAn2AXc8vI0yYwMcHFmqPD0fXcQndTbtss/arn5bz27y/vthcEBGgYbOzUCQEeXl4jFG4UGn7tpdkkE/YtqwZbPVkL5IgLtsRMYfP6/XMsIyAUvK2vQ0GCtdxnc9oquq6sJHm2ZqFiJhLooKMN2yPf/xDMJsRZuME7c8/ZyJhiQ3Lwe3jH+46rir0DlrgNme/rpezlFJ1pZjrSCcMIRMl9Glzqs77SI9ag8t2PB6SvcER6eFSJy0PFwuU2o6wjj6v589Y81LM9K3U95GG9PDr6VPT1iTChN21WGg/z/N8nczbJ9bH7/9eACEQQMdLQPvzz4E50GeU3al9/A6JJX97OJaHX6+QUusaoL6PNNwQfg8IX7/F4vmNXjfITKOEqY7efd3hEkx5h5s6ctl2GJIG+516LexMTh4wj/QtfGCmSswRgEwg8X6WsvE/Xd0DiEVul9h6OHqS+JDXRIBQatv2x3XxGsSGa7aiW2EEmwMTuV4j7/ACmkRg3PM/21uVVULnE3a/ZKvdCUuYwdcIAImDWURKqXa/7ovA7Ag3AiAJRyy+mICeeMhM9N3NfMo9uMz51/T8LmI8k5VojitImOocc3RylPYYjrhKsCAqpbZ9L9tWKeZpqoECXP2O0mtugr6yNN22fWsL7tSLED6/n8/X6zyv3m/iR9qNabo7oJTt0Fg0d3PI/ED4bXCKQDCLEWdyfkdc0NWNLYv3iZRA1MGrPhzP8+pDzT05hQDrEHYxlO0wFFhc2WCpdduvi4oUigkzIqiIB3g4eEQwJYA45qo8Squ1FQK31YF/fX8vg3NCZylaAPg+poRcN3Ok932Rkux1BADJcLIAmcir6cSUmRTkX0w4S5ZRmBCWOWiBIW8PL4MhPTzH6O5YNkNp2ftFCCmtjT56ICKGKt1JdQKDwmPNYiVG5qWDM7sBt9UIen09n8/zfOMo7pHer2uJKKUFSNFVvfituPHLw7iqVcjRcziQlNZagzHmmERQJH0MCZufg9ZPG6/zusZcHk5xgznnGAMcpYFsD888K7zUlEowMzdXi1IKlSIFfXVS16Suxj0/xaVnEKaDMExvD/fex5xvfRMETMKROyFXJyr1bss4MRJjVk23hxOMjIsf5ZBkxA17H0wIfhfKQJjvAqwf15+vqw9V84XmC1PVMXtHR0FppivVc5stO5yzjzF0DmvAKG3b8B5sbOsT/Tqvfl4XrGGamw4M13l+fz+fr+yWzSSr/dXDHh4kASLbPleCrJi92KyW/naHgSkZ6EFc2n48sCyyS+LwRCI7wnPA0DHnnD2P9Oo3AwLqOtIEWKQCgLolsU0X2ux6vUBjnhMkULbjwesNiv049uPYDzjP13me52mrTxw2IdxGOb+/n6/zPEf2B/NIv+8wh
EcgA3l1G2NNJxAD3z2tv91huBn3gFLa/vigLCnDytI0cFrEiMg7NvrzzF7sAhgA3XdYFvtjAcRtmi2JiCfH5JhX5+pU2vGD79f/8fh4fDw+HvB6PV+v5+uV1QaEG4bbED6/ns/X63WNvMPL3MWHz+cHhQMCMlvrYwxYDx7AgjwsNR/myKlTOBBJbftBq0ExRaSUIuJMSYzwMfros/fkkCVmZhG2Mz2ayFJLLXXOJK1kohruEvOi0H42DZR2fNxgaf9cH3g+v5/P1mofow8wCAM3JaJzhazZF13B3z7OM+2BhEBI2HvrpfRS7mQCAECeAOB5KFNTISFyOl7o86J+jX51dcwRUvHHXinmJaAOAlI3BB+DMWLhuOQfv//4fBx7q0JgEKaa4jiFdeFR+7BAqRtsrTCGa45BbvSKm4ImDMIX53mYAAKxsK9GGVcGG9fzp/38ej7Ps/d2Y0uzvQWhqgpcic3VDFz/YvDVEwO0DMbQDqH9hZo/PoCLA5Xm21bI5wUEAMwAFj6vi8GBpBSpUn//7cfHx7G1ikmkoTURRY8w13kTK+pGreakGiI87sASqR9guvBWOuecBinZU7y22VSVCoWN87vY1/f363VevW9pb9zjIY9858tUAA9LnNbfPBy3wT7B9SqFVmPMkQuytC2KFAq9TFb/Okz7WQgjSOrWWtt++/Hb5+PYtxoBFnCnQcTuls2YayYfibdamMJ1vRSBRBjuNkH/kkmqzuHInAY3VVM1Kgw2zsL6/fX9fC1NhvSwTrW5OhaM5DwiFPyXwXFeY1oA8t0lCQ2l7I7k84pUSJp55gxTe62NudWGc5xNGB1I2r4fx/H58ePz49hbWeM9I2ZmJAKbWaT0YYFcvWy1MILpSk2Wh99HOqHSqnNOpxJAJDWmTlUzKgQ2LgK936kxp7oHQNIJ5+AiRUhKELhi/NXgPoaaI8UNHAqdABFxIw0ZmUuWM5nVu4Kg1OPgfj1bYQxAqfvj8+Pj4/Hx+Dj2JtPBdE7lUoCQBQZF6LzOMTVQAqJVyTucUkk3As10aWPc6L05nC0QWUqobpYGhw0K09dzGbxgmuE6+nX1q26NhGpDdGX8q4dhzDntDXlCxJW6G1ZZbYZUbEIYc4TqmC4NuB2f5XztrRAErM74j8f+OB7HvlFY2OyjFwcS5IJ9tZKmpUpN1AxahkEEqfsGEf8taOmcXi0AWSou3+eRBhvzfK3ceqqaA4DruM7X+doDpXLZ0XUyhids6QmwtEZ8kSNT8a6P0UfHvW3bBphQTxGB8wSNeZ1RD+C6f7bn994KoxOVuj8+f//Hvh3bse1tneDzbEASyAUFwW1cpwYECgPKgtcAA6bSHCK4mYbmQ+VumamGeuIRceY/RyKwsM7zPM/zyiOdAAPX2c/n82nEDahsbLMzgqm9Dc768peHIXS8zvN14mN/WBAHcsJqgEEp5vWkw4Db8bl//dxbYYwgafvjx+//ttW9bW2vMTFs9vPpJDWQC65S/bQ1alojP1cApMAFvgw3TR2fJS6jc87Q9HBBdU36C4YpAIx+9rOfWTxZdpJGP59fX1GqBZdNZs/Xz2BxHjAAkGiNUBEgKbSjX8gkpcw5mR2QWPCtqZRFdm1VJNVHEgFIK84j+V3A3KnfTQ6mSD2aCAdwdidPEBy9Uct3SZNpkKvGnUUCcXVHLBbJKk9SC7Hc4EQe70IZMedz+RsmxsNWJr2AE+YW2VZD4lJx0QX5/iXIgsseWD4+H3sVjEU8IKZ0Z5V9mkEwwfoZ2srSWUQubZ8Wc0U+yBmbUK0AxIF0j2sXGXHpXkS4m845R6/hyBVK1Wlq01QDgLiUqY+toM/zm86uQWXzY99aKW89xAS/LWmapR8Xc+pU82QTSHFYdMFFVwogA6qBZXv8eOxNKPxmNyK49vMpNNWDUJiWwXbrC9wagJiF3Zj21g0MJAnIWvA2WDjR+flSre5lRwNG9t2GjjkmISCK2FT72Av5OIl7t6ASeCTTn2/NT
PrlYUaupdXqvfdwsLk8HJDcqkV9BQDR4IJF9fh47FXol4cBbI7rSZAs91LIAZKSW//iYQuk8zotdBHRhUVy6A14T1U5eE2q+b7WGbt6YUDBQIBr9M4EEcRWzcz39HCQqQYVlGNvtf43D/+Sptm3zYTAFCyHIRKQVCOExf4N8ACqJSK247FVwUidRyL2cO1C4O6AIrWKB5J4QC3ChAm4q2vYNUD785JE4EmQFPXANUheHs4K/C1mmh4uLCQkjOdVhDDbnOERXqWgj1CKCKAicexbKxlhkrvwPtIAxHU7jsMYfA7waR7AARREELbebIAFeyEiatu2NSHwXx527QxueafaBEv0DpZassgEkoZcGtrg0P56iuQMF7jU5OTkkab0sEhZ0cZ9jVwEEKSWVqnWnMoiJBghiIB8ar/7RpRHevFIif5ypCE9/PE5waZQ2AwHJCEOQAh3vClAECW7XVJLrUUoPLHwxAE2Cc1mwuuH5h2GlHokhJQX4dI05klg/fWV02gpKHUmgeHXkWZRKSa8CB6mc45ShAS5bscuRbLxcz8Kb0gZlVKISinHvtVSZD2Bfw1ay8MfP6bPLgSuCZngNVdKHdtspFChsm3bJsxCgpGFHTJ7uILrvJI/MSY7ACElNGDdYeTi5nZWBu3Pr7JUT6gOtTzSd9ASZRGTtz5velhYKnDdHx+VGcDnLLRQPDTHmDrnwG1DobJte3qYU+vzDlq6PFzb8fFjzKsVApuLdIW+KJUOK79HCa7b4/HgfCbjFgNmAwOzwcxS2t77lAAk4AimdTCRGQIA5lfl0P76qjkfU9761DtorTssYmL3g+aW6mvMZUdp+8dvG6H77ENIOBuMJ4X6uE4MlKDSjn1LFA/8LWgVABAK7deTYfz59f26hhrl4IhgZQ/3RZVb08BvSp5qoNT945zrx9KefeSRdXA48K3j+ou2m6SVrdbaSiv12LcqBG7DdYiwRL/G1XvvX9+XBteDt1YYXXu5qggxzefrSnLl/QFEkmKGtRbhTAzQ5zjj+/tcqF8AqQDAd4dj/Pzz63n2aUCBmZAsi2/tyvdXnpMWCFVHacfnnFldMGX01gF3KzIvqkvgjaXBhNXtrbZWW22PY6tCGOqaVZONPnrvo5/nOYOb1ZKSCaNfiXno32mxh1OCIwOQuThg0icRQ8GGnAzneQ0DaW+DiUAv9HmN76/vNDhx4n9JD2HhOzMTcFvKzKEzSNoxTVe/hB57FXQdt3CBJxrybgQgYoKutj3l8rfajmOvZcWOiIDQm4w45pjBFY2JCVydhREjvL1e59mnRYSTuy32lQPhVkthppRLQsIYY04DrnEbjBQ6Yl4yX89Xyikg8ko5bw/T22CCcLvFxpeHHe0m1m9bZQobamrTTK22ph53c5gAgJhLbduR0OuWAyQG13toNmefY/Q5PNyDWvXMfRRIEsjTzn5dfaqDo7shYgCSABLVWkUIMSzC/SbXALe3wUChPolwXud1nX0aYCq5BS6SLwASSVm6JOF269maBkrzpYQGiFjLOtKr8a9tNwcgQgoMAoI3cHLb9rbtbWutViF0W1JjS8t5jImcERYWJciREMJMs585zYHcHQ0gUsRYqNWUMLRsBunikAngbXBEGESEz9HH6EMNs9m76pr7Dr+PdITBwkDYdJKGXP3mSzAzo9uIpY88N3MHZOGI9C8kcHLTLTnrey0ihSBSYrhfvc81Nii1tsK1kU6d4ToDIcJs1lS8U0dHd0REBwIkDiq1FCFC1zFGn31JYHIiAAoAgE6bS+YrRSUow3IWx9nEfhvMCOEea4xg6igo1bIRj5FuDjXLv27MmRPlwsQIkPKIUmrb7dj3bd+PrTAxEXiM6/U8X69XnzqnDp37sVPh9uDRu1voiGQh9rr+ckdEtBw7IEMEcJGFbbJxnud1ettaI+FKbw876Lh6v/TOVni9QyusLvqdSCm8gpYv6ItFoAgsyEKyyHIwNntPRaFMzKsqJ0P1bbAf+7Hvx
7FJqsp6jPP59fz6/r50qupQ/TSswe0oF4SC6/DU2b3qvbCAwhHtrscxNS1YiMC1X8/n82nHw7CANLo97BN8jut86v2m+bq1wJLebTWVhXLSpG6e5ajZ6moCBKbJpqmsN9c0ZLXSxyAWitS/TdxTqa1t+3EclBN8t+v+TJ2qQ7VsBih1L6FKGD4dVvUUC0Obo6Eg4xS8XhNmvEk1r6ehlJwEwI3T4tU++8snmR5EUs3M7Hg8Ho/s1i1FoLstlEAfv0dpgRma1dQMBHmztjVB62RFihe8QR/u7kAkUhuq50ThuhRkAxljjjn7nHe6JLdK8V/m/euDEIoYWEQAmQonZBdiMarpVzfhl8GcP+1OWzJvJiLh9TBu+77tWxVapEuHWHc71lwo7jfWZo6ODQGlBKKUImAjZqvWkPi211abqjYItXGNq09TlL3ss48+Ru1j37eWicRqfeBq7dMi0UFATo6ithTs5rt/lGgoJFjdhL8YLCnc6X/zMCGx4JLlW5NqIVv6kXcSgeE6p+qMO/GYS8kzhG+YBCH6mGIWSBzvtmFCTEvdXMHm9XqdgQAigNqvfvVSrn17e5hoEZ4AFsJuDeBT29QdgANp4YMD4BbiQlpNuL96OJda/NXD+Q7d0u8ihaWwoLnNOXXiEkHAVCUbI26E31wC87DVwlttWwYxdcpo7f/dw9bB5/X8+l6qC+LXebVThI5922ophfmdw6f45S0ICR7qljxLNgcS8QgIj1hc8/tIs/ztDq8aAf7Lkb5BwryWNgSEzT5GJxGR7HBq79d1xuJa0MwpvxIySDuOx5xzTpsTAkjE4tcdTg+3pmsS+HPbNynbvsGrnVWYcd8zK7l3mDitKpXuzE0NQufok1hqIIu42830y6dmkTh+GWz31/c+z3fQkrrt27bt29rxAZYwv/PiUgsgQWqgvV6vWJpUPBfom6WB1OPHj+vsoTbOQJL7wc6RTWQZsQ0Gn+fz5x8fKCDbxwetbtQyOMUP7jADt9AcESIFgdvsV2epll8p+Npwk0NY+G9BCxfS4b95mKXuj+N4PI7QvCiRzfXXizcPIk/0xfV6fruIcGGRMXTMOWbdHHk7Pn9/EUywcRqx1OnxF3IpEEmpDRhtXs+ff0DZUPbP33m1KmDftlZrKciUQYvu+0ZrY4djzlau0hb5HYIs3BbTEQl5Ld74m8E5in4bvHrN0rbHx+fHjw9bkI4ceF6vJ3sAiwOGz369vr/8JkyNlQbrYSD18eMfEnqh99fk0kaOf1L68Q5aLRh8XN8//yW7gmwfv5eSctqxLQ/fPvm1W4gxO1VG4DbH9WrbNEcScV+bA9ZfkO3v+0gTACylexEPTP8KI7hN7pmzYOi0OXVqP5MMPMqSrsJ3tk2c+9QupnDM1yrn+jeLJ69DkdfVpzngDQs9r6EGOV9OFI90uVU2k2KJKYlXd9n3fd+2ra3lLGS8Fqbd0hd+lx8Lec+Yh/7kXwbTGp2sbCmwMoVPAgT32c/z2xOD4+N1XlcfauR3NrZaq1i3re1tq6+XYCjBm1N5q08EgWtnjPOZ+sCJfj+/ZZzDUbaH/vh47NtNJLXZr5Ry6R1f5zCQzcq+79t+bA0XGppZpJRSKhO4jX7KCiLTlrIJYti4CJx+GUxJksoxEwCUbFB7avG8nvvSnY55vs7ex1RasIzFK2Cmuh37sR+ZnQxak+REQyOxlGCw2THsel59OnBumXoVmmcKIsXnR+ogIif78jLXOXq9qI9uKA3qtu/bvu915VmR/etSVdaojpMoOBPJehtM4JP+6mERKQUWnAmJEXy6at6O1mDhe/VKxQZNgQr85WGu2/74eHzsjGEjG7hvD+Pt4TVRPH95+BIGSw8rfn48jr3VApRUlUtNR6+lkZkZCkndt33bt72uWBN6ezgXJHQmvTOf7C8i5oGbF8GCD98ZRlmdyUyKPRTn7LXUUguuLR9x9X71eSuNwHIwsUhrx8fn5+cDw8ZVEDwS8bA4LCxO4Bo6Z
F59qCMJgs5O6H5Ox7JF+fF57NtWq6egw7hIR/JHAQGgQGvrnayr3nARLlLqVCYIHQScZb/qmooGktuwuYi0BMtikVLePacl2+G0ssO1RxEJ+hh9jqnyXz0sddsfn7/99hE6rpdQPjx/8zCCT0VCXQrQ6WGICWM6ykbb7WG7PYx6T5nWL9L2DGs1ZTbNbw9behiD7kL9pgAjhvpc5cYyeN3ht3Tbwgha3GOKlZsx5laFlJL82x2Wuh0fn7/948cCwi9pnL/e4VgQezNTc2CksAGuHd0dC9f3HYb7DsNIucXSWmtF7ije9pL1GOl9hy0l0sNwqbwacbakMdx1TRN+eZhFpNx7uHAOV59zWnZtIDHWRQolBd5uGOsvD5e67Y8fv/3jt3E+v6v8X0TpPOCJLfIIIMQwCB3MiIRCyJ8rSscdpWP5qByZhB371ra2t63oVFbFyCitxYjALWzijc/kYAQkRotEff8iTN8D/VuxHJ0Vw23qHRpSYqkU1oUSfa8f5HthGUupddv2LYn3tDSSTN+rCs1zIUB2yoiAICyUUESQRKTsWytMcK8dUQ0EIABSmQbEJRUXtraXyZMQAUxVtZo6pnY3wqJLOjpllhIJbptv0YMbHJzFgAPA0gNiS4M9sjEXhom7yESYSEqVzAx0XM9WCPz6z5+vGbJ9+sfeGF37uK6rJ2Z/6hw6McVr+C2wJ+tLA+to8/pu43//8d2N6iO1HZFlP/bGYJ2YPYBIcqSVMnMREAj30B4IaC1xAp82wlPkP4MWA8BNSljatOBggZyyVymeGICAoYBgS39jccrLrWOgeL0Kgevrjz9fI2T/jOzIz96v87rO6+pzEd6oFAnkwrfQLa8VJmDdxyVFxp9/fl+G9aC70imtVEHrXEr2QAGR0IjvXhouUH02LxfEDUzhZpdiri75m4dvsWOPQEL+VSBn0E5GRXJFcs7PZU33TeMSAp/9+/k8J/D2CR97ZfQsHq/zPHvCw1UlkICklHsJVq5qRAS1iUiI8/l8dqN6iBSpIkWIiRjMueaaAwFAJOPbYEA3M4OlXEpM9KutmqJo8N8NTpiTmwMCI8CC2oSH2gy3qb6oNmAOWb2nYoypXRg+r3Pvo8+QnXDfqmDMfl3JeepLWFsLsAByafd0jW78zmLomI7Rh2GFUkttpVZJEoi6jOaRHiYyY15HGVBtYlqP+ZiO6eZzDKMl1/jfjzTn3EJtaQHyu7E3epjrGLkXMgDMA5HuzrybxpLiae4eIVS5lvTwdV7n+Xq9+r3FxbhYEJc7PYS7uA1dg4dEQWOV1lrbWmucp8NmmdkoEiRiM0+KOyLgnACOmKgeqUUIdNi8ThMugkS/uIe/PHxXkoLMhaXg7WEKo7BxTVgrDXTxte51CWpgs5+l1rW5lYSZBX1Gv87z9XqeVzZ2VaNUB+TSYgWabCkTgo3zOs/zshtNv2UiufF1XVdM65K6cCRI7OapGQUISD07qunhWlsFJbBxvWatAQwo/91gs6WniIQstTSKhQEAGxg6zvfKi3fQuj08fXZhEW6tbSStFcicza4rPbx2JppiVQ+SUlffc00rEMD66/n1/Hp6a601qe3Yj+OxHzt/PzkUtZe57jCGO3tYehgStJiID2Ip29Z8MPjsr2GOVIKY/pvBuopKZUGS2rY36MGnoOu4LmJCJqJMGO9nKUyHrt7/x8cHV9kf9Q7n/TrP1/l6XjfNg8dcHl6fNSeH0H5+/fHnH3/Gx+MDBOvxcXw8Po6PgyvH7GC9JPKO760KRknhychJCL48vO2zENi8Xt2B2QGFAEBOAPA+c3MsEJEIi9bWWqulvD1cam27qjMQAN0Sb6ZmqRP4+Kz3COKuZjUigJBClgjCtJWgJSZnjH5DZQBSNzBSwahuXsuiHOZYqnHbeh+9r8B/fqNnazLmGFdKo67BGFikamXPBAiRWWpte6oe/gGJlzbgGlJrvpS5mkJu9jxCblmte19TfcvuwriqBrfHCJxvg2stq
JeVN1rolrxaTUSgcO2XkN+ctZRogAgF2VzaR5RaK3l34lJb8wCWuqsF1l1ivtBeDrmg997mNK7XeV1X79jWt9YXCYr23HH5YACQfwEA5DLSQkW3nBqsl+zXhnEuO3HdHmORAj1J1NdVNaQdRkXxLi4JCe28pJZKVGrx0YsQxKJ/AiIkWkhp5bL3ECMUC/D+MQOBEK0PltrmvdHDgliqwEC7yk3AzRR0qma4653VA7ikwft1XbTtx/F4PD7fHgaAAGCGRdu1m4OEyc0JBK4odU4dvV+9YyikrtRVJ3BzqrutyhL1vqnbTkJ1261flWkxVwA4KGwOApvJg+GEruZQQbhBhCfFc5rUbYzsRjYPkgJEMe2i+/m9dV9S4+7qVxf1dabXB/f9eDwea4XYH/kOJ9opWZ4eYbe0EAAEYABD6gDO1+v1ItelsXrlkaZyjJwtAcHZr+uy6yLHElz2j3m+EiUYGCm5B649fI57afHaXe4hXJiFxa/zumKMU9rep7ojFQ+S0szM5998kjlg+NWv6+q9+zRAKnUbb4OP4zg+Pj74bXCtpRHXSrE0Yu4k8O1h5OyE6FcTijlWS3FcNYCp7uErgSX4+ibven0Ryh5U98d4tRRvufmnGIvukYJYRRZp0GNj2dq2Nf/6+o7h/Vn3q0+1AJYgKU371fvo/Rr3Vrlsc0NE7/3qvfdYHn4bTPu+Px6Pj19HmvY9uHDd3yrgffTeEezXHU6kr2grGNoZ3h4mElqr7pEQCCpZR7u+QTYFLttHT9JW4jMw4V8aysS1llLVCma/xkI22T4+Hx/+L4Hx8v6sx3UfaZRibi+0a76+v1/vkLGGangr+oLafzX4OI7H4y8eJgWqQHVfgRXpPE/GsCRkAgJwTSK9C7mOl6xNReMsrXKttbZF8kIg70/U6wu2hwbV/ePcWx7pZM0wYpgpIFBtrZpFoK9R0A5l//zH778b+3iy92db8DMgzDKB9Yrx+vNfX/f0cu1KJ+pj8TTUU390bNt+poeP4+3hPwGAUYoBl9aWI7EKRdjEe7NOkGTzzMNHP2vuO9TZe+dCZd/3Q5CREDn0+i5o19MfXR2lLkzrL+DEmx+NamYRiLiCZShwe/z2P/6n2fVsFPO197W0dC2RJntRzNfPf/5xK9GsjaDM9yYBXo95a621/bou3o/jcdzPUr60uViq+y1f3XvvCtzQ1FzVNHBpN66vddU3iGgDbLxqk0XupvcMlHy8flaKf/7ruyvWw++OIEQy4FGWWHWYTYtYejFbYX92w7r/uH47Gvs4f94zDvrzOxH+jenOt5mFWPhaSr5L+52XamCf43Hc8odvg9Nd170SOUxtWgjQAHcbcwBRUYsEJN7DWSJCsG6dSXJJ+dZauYXblby/fpKPn398XYb1wLI6Y7j6RpBSdMyuGuAWizbBGN/dsBw/9OPRKMZLlBJiQ9/PcxhQaffsbv0M4ddS8r3tpYQ3TpuPY2+tll+jlvTwHH3e8Mg7UlE4hs7rQpY5bw/fCGREJDAdAQFQj6T80Gr9NCIfL/Lxej2/L6P6KDX31Be6ZTsW1hltYIQhuM7rLIKBz25UDsPtUdnnC/r9Nb++z9xeVEqRKqXKEs/n7ZmRot57SLjUtk3z+VjsB/rbkdYhgquHbTcZgV0xbFwvktKaefyi69xBw2YiOurn56cFlRKAiVQhHy8fr6/Re1essdWt1dZau1+/TA4hYiK4EULY6C+mMDq7Yd2xlNooBmrB5eHrdXaL1LItrdYmq2XNqWzplppvTMxS2lRzvY/0L4PDXXUOiYX8m61tDZgrT0bXeb24tKl/Y+vdzB8b4xpXHy2/+bb5AtoB+vD5kpJla2Xctr1t297WnjaIFauMwU0TLtKZwCbPaViOshMR+bT1ACDS6H1oUMGtba21bSu3PEERQncdUouIMBFLVbUA2/c1oft70GIiX4o7/fHQ4MrNR2J2X9L6yM0vt4vfHh6v1+v5em1dncp2vBUKDXwMhGwxI
ZdW9v3Y9mPf+M4Z5hJBQDdlAggblIk2BGAtkS3X4R73X6eqakBFEle077nBKNGVYDr6e+Mbl5Rms21vWyt/P9JmOgktE/Dz+k2BG3CLkzF0XK+y7TNn2f8lSoON8+vr59fPXYHq9ph+R2m9ZzypOF7bdhyP/XgcB0OOqpamUu/gOnNBxkz6U73v1Bxj6Jhdl4fX6lgquO2P/diPR12vAxOGzzEuLKUsDxePQLa25QYU/LuHCWCm0sDz1JCmwS1S9Ox61b1PzTfzv3xsnF9//Otff+zBsh3XImuU2qbPcfXr6sfjEYz1eHw8Po7Hx+MhC4UTqT0t5DpWH2GG6xBp24ZStm27XmgxX6++KhN8rwnY8239qCv0M4Rb71eDcq+LKg4ALF5aqa0UfoNLke9Oxy3x4zbnGP2CoeZAnHaP8yVnz4WcGRF1Um40MH+ruc0s2mvVmRvDpOV+x8/j2JIEv5KyeHOoFghjLWWBoOLAUrcddDKCq6bBtP4DV5n1pjel+gwRLsDcasq7u6kjk3KWoakgDkUkFee4tCBp++fG3r9hwvd3D9l/0HEUmC9U/vp6Daf6KJXBxon9nCHbJ9b9H5/JhetDPZBL1Xzf8PF4PD4+Pj8flUMvH89FREfo40q4ciBxaZZ4G2au99b2e1fUr7VuS+bGU0mbUJcaIZ/neV7XdeG97UjH9Tpf58vqljK2dBuMzMJEGMgFSNqYrbH3mC/ovbvstJVaYj7nCed5DefmUjmso5VrgmxYju33z72iz7PPVNQvbZX8x+Px8fHx8XkQwbR+byNBhDtopWxlC2RmEiYut8XEzFKqws1avfmqbnMyI4amuqTKuSym9/Kfcb2e39/f2vat7W3TXx5e+98AuRJXMyUi7+NM4hnvzRER5jwx5pzqVJGZwYYPNg3ey27t4/Mo5OO8bg8vzDWnhz8+t0h9RH9fxftZskCSkssAmFL56G8efgtfk/uC9JsOIgDXUsus1fg8z+s6r47j7eHz9fXz58+5Hfs2tmm/PLyQbYBM4pHcRh9mBvcoWhO+rCnmRrUCEuSYGUkIkeq275V80DrSUm/uz/F4PD4+Pz+2tVh63LBMjLxtYQHEBVDyRiOV+iYNskg1e5OP3TSVQU1n9pTrrEVVc3l6ejg3yuu4nt9//vHH2K/jmNP8bfBNwgJY4y3o5zX7dV2w7xu1fduv89R5vs55Z67hHjbDo5aWvLxSSiGbMWbWZ/fAnB+Pj4/Hx8dne/mY5/l6ZSVAq7ZDxAjkABKlVezJXz0sxTwVOZnIJs0IA7OJAG46Z9WiquVcFvPa6eM6rtfXzz/+sx+r5Prl4Vumb8FOGb9jev/6/gm/GW2y//jxjfqaz58/+1a3rZXaPPGTMw6osh3HUWCBZqeqB3G57ZVjHeliA/T6+vppN/CWV+oPWeC7ZwoHxH/xMBf3WN0vYuUBKcyDGG4255yzmtq8gxbfR9rG9fr+81//eT3WgkB+G7zUOgElyzXB8fL+/a//BKMWsv/4HzxPmK8///f1cTgVrg9Fnzauy0BAts8fP2TJbuQQD1MRlZhFHsfj4+Pj86OMV+j1/a9/2fIwl1tEIhCRk/AIhECUAFpEZBbzgFXzEk8EN4VwzTUDkroe1cp5pcHSx1TVHFl//fzXf57XXAvl3kfa3R3dAzl3QZT5ZO/f//xfQNuHy/7jf8brJ4zXH//n2ZXKTu0YPkH7+ZxlD9k+/vFv0sfiVXlEIGcLllneHuanwLy+//g/3wa3rW1bC5CswFa3BoFI/nKH/a/jPo4wXbxhMmIqWhOAn+Ze3Zb+gevy8LOr5mQIbj2ttyrBjSQkllq344jGaP382r/OEVyPH/XHx7FXISSpHlya/pbP78XTHKggZ9hN/uvUOWYfIxfHRCBzaZvek5XFI7gP/0J9IyCtQjfnwERiiIjgYdOBS3DR1d7BJPpl8lFKrZUZbJxff7Q/v8/hVHY4FqyeYclLJYAcKXKRVkAgl7odj8/YCmj/L
vzzNZ23Dx039YGkIbd9+OOxF7TxZI+U9NM5V/W1aq929d7HmLB23RwzEazvHRcRiRMiXkO1pJkKE3EOZZdIroW7AVVu5rDYntLKipm11tbahhzan1+Nf36dGtwe8vE49r3VQm8P31pYiDfbP2FX3TZB60+K12sEtw+c+77vTSiSrGQWrbWCdkF2tBBxYLiGzZ6KOHO0XBM3ce0oOWZu9LsFLCInnFxY3h7O/g0RABEvCSIzNQUAZsC15C8iuNbE29Saaz6DQfvrJ8fz+5zB7aiPVRDTLw+vt+m9Kg1J6naMqbuAdfI5+gjeoHhulyLAW8qTWZh8qIiIkAhTmILb3VMb8+rpYbL756qpQXp4jQ+TqlveHqYUSiLEYA93GxgeNgYumSd0d3ML41pqKbXKWnbSgsH6S0Cv85pB7eHHfzMY1shs2Xsf6amzFtTLx5k7FuWIIoUXCTnBeQk1myGtInOpFUwnhs5+r+fcrt7HGJNyRrpNGUpzjc8SRgPEXEqpd0mEd/MHKP8/RuEaOi4G5lpr44xUrpRr64qkvdtmHNoZtOuYM7gh7GuJyV8MXvjKv3t4mE8ksD6RkYi4vPMdhAwWpfCYiUaSQAEq2+ZzJGY7i90xris1CtnyizSmiRApLf5XD9d2m3sf7dUpADAwRbdxFQau+37cEtiTSqlSS0l6RmvbZLAO1l8R7sFUaNu37e8GL5VV/EWe5VI3jRgWbtPCa621tNJkjXSDpLW9bU3O83ypXacgV6CyHSluYTM3VfbRc2/fGJIZ5x5ElD2sm3gFQCRSant7+D53y9WkNhhC5+U1qGwfH1uukZqMUqpUKZJBa7uIQcP6WlrJjKVtW91arfg2mDDe9qaEPEk1R7rGmDr7HMcDCrfjUW9VAJR2PI7HUb5+gl42niLNgko7ximEocnw7r33/UrVzawpWgAhuCrhMvf/ysO3TuEN1tXZs/cCBlz3x2/H6Ffv0pMHVqRyLbW11jaEUB0AUGsrxK202lpptRYEWBummRmBcNGX3FLhrgEyhob26wyWDbhsTU1Vw9cOzbaXkWT2GZqg4oXqfSMt3x/P/UQeuTLTbb6hFiX3SbV6n2J0SPwgszMHwy0GOFizhySunFuDc/qQyMe2jYELme4aiJIYpryGAADyBZCK04RSmQhC3XAmBiLCbPLqhszRyxULH5siTgxRntfQQGmEPk+O+fr++f3qBlJyYwJAXS0nJBbP1RR5Wku27rd7ktvajXQMW82X1QAn+/76/n51Ddd+FiG4xui9j9FJPYJy3lHbrlYWHs0Ew7WHam1aTC3Lw58AQG0DCpTKAOiBE1QtkAFBZ5LL02CR9/7DMJ0XhsvrHBYkFckHxby+X6/zNQy5RI64YxXzhMSSQ/FFTKlt2zK7TDBWq77m0jZXSHp3zs7XeZ5d3Wa/hMLP+5XnFmv5K0vdzKEkGow1FyjYnE21WqrnpofJgCRIKt8ysmERJMQwhyS6VecYfb29mazowLAp1zUsUFqgz9BLZPQxhoHUSOl3SJmJFMTJ/l1OEUnXOrKtphp4q2oanhrhOQsEAAwE8Az5M1zHiWHa7q9EkqLqQFyqemBJigchEli4Yi43cc/y8AsAKEhqoFQyNc+sAAGZEUeRX0e6MMcvhoBi2JLrCpJw8jEvRHA1MwMuC6YIdY0+kBgAEZmAkIhJty3x7TWHMKUk9td03DziW1rcF5DCfXYMn6OsRTBaUs0isucAyGWMOTi5heEKwbrUx/A+0kylWpA0nG5hY05OsQHCq0qqOuicQ4jjrjDCNUwnY3ZoBN9K+ssp8paDvY80EUNuHc6xKVvi7La91FUCpJyCzZFQtvPlt2j7PY5y7eDar1vd1RqSSILVSqqAjVRMBltCCZghE+LtYZa2a6BUTNp5ylYhF8GWQklmqeCNzkxMwAAWOQCBgABC0cjV59kTEWa8afI5zrvHf+yWdb2wbXuOD+5yh8E0t9301+v7+Xw+bwHfWM0ushk+u8h79
aSRlLpkFQJJar0ScRqo4NOmgpmbxzL4JwCUdqRaLBqh67wuq8TIpVJuLA83nUyMECIcgoSLMw230A3h9Hm+zte51dbq0gUACIDci7MgSh7ujEzMIrbvSbuSbLQw+VpiMK7z+f319f2lC38fSa4p1dVmEhlWweUsdS2bLshSbAgzAriHhekYPdzcIP7i4bJ9DAuSBpOT3QUogVwatXIfaZ1ECFGswK3VnbjZIgVJCpCP8/vn1/PYHxYCTGvpFtaS2jiEDhQQzsgkUorv27Hv+37I2nBLlvg7Hf18fv/88+vPqaZTVSH1L1DQAnP0uAQAotS+qScu3Iv7EKIkH1D47NeZG7IiEdPybwAgv30crRAYRCBxKTXNtGygcqmtrOFNFuKrPHOPDAtmMvvrStiQsOTp08QeRBboROgeqd6qpo5cUiG9lEKgmXH13rsG3f9cBMilmBuu276hr5cxMd8IS/3UZkKfFg24qGbJTER+7Pt+7At6+P8CAP799x+PxqFhgFJaUJFC4NqnWpCU7UisP4S74SJGJBYbwqchEfZnrsd11TmIENjcHUjwHaXDdE4dMy83Ct5r0cOmqk3VFEMUBF07WhySg0DJ7dg2UHN101tVllrqiOoI8IBwWJzV5GRIrVvs6/PL4MfHx1E5JmR+hcLEBAY+pwNx3Saz8Nppm62Zt0pvWM5xxut1Xn2a28xRVdYZhFx+2TX6Na6e23eZZKkZkGm/rn5dPWO5VM4lg3Peejfc1iem6rQ5nRaltu7bbXDmCJl6Ws0x9Rwzjm3f9n3f3gbTtm1bexuMXAGSXT9nSl/bXRU6poiN3xUrxMqNxnVe15jqppMgwsrCuVPGYF6bhc7XSVKlisgbnGB6PZ/P7+ertVYbSStz6hxzTkyyLEtttbZam+WuvWGcGTSVVgpD6IDVylgeDiAR1bz/+7bt2y+DUUoRodCIIEZxW0y8mLlbZI3+ASL8XnxBxBQEZAnb1dHHNXIHVkY5yYWCsiZjzATJJf+SrW3IIKVIYRbC0P788+eff359PB4f1GTzxYgdWVncGny1FuvX6L3LZBaSm+uOrjd/AnNbVyrZmpkpbdu27e0vHs5zAzEDc7t1qJq66+3hRSsOAAdwInTkFEBDBM81Bgssoe8OeakFkrzyNw9//fyzTgNx/OVh0Ov585//+c8/fv/dqIbssMSHRm4Sr63cn3md1yVMfJMROXcLqq8eb2r9MCDJwvdxWynsL4PVXM1MgVZ1MsZwsDnWHUZe0ECPCHREQ/FUD6Ew7SkmkwwOtwgznVw9iIO4/vUOX8+vP//VDLg58q1IQqD9+ed//q//9Z+XUT1Cdsouyhw3m7TeghfjbK/ChLSWfnD2gJZiByAhpDI43Zmt89byz9vgyFJdNURQWIQppX76dAsS5GKmZmGphYeA6AWRABDDU+Vh+sqKlMKMGHVNX+tfonS2xjfgqunh9U3o9fzzn//r//hfTvWYIFuZuQ2k78f+2B/7oy1UkfRnguXpdjncOoWZ9Dphiq5ErLAKkrXYMvgfAODni8Hmfd1rETelsNk1AlA4ckcr4AqEeVE4AvAW+blmpGoNQHg4EoKslcRpVZIwZj9f3962PvReoxURruN8ff/841+Pz7Orv40ppbZtf3wcH9udsJZEDYpKqaWVVkIHRNjULJyJEQIQb/UBRJRtkYEIAOR/A0CMMYejNChSE/KT+7DHvIOxaZYntpbk+d3BDiSp6oAz1qnPXiNml6/WWgtnZgAWSKVtRxPQ60l21dwS1ZKF9uOMf/txVPL+5K/8/Jyzj+s8nlvNuXe9ztfz+/vr69w2c0QpuVhk9rkWEnPyemIBZxDREFncAfFtsFuYY2FckSBV4HRcypJD+PfekSXsl+3EAADiYg5Eby2uJWwIRfJX/EX6WE28h3DoRdafN1f0HE71+E35H597QetP+vr6+vr58+fXGL2/ju24neTX+Xp+f339eWWB5DkpHr2PRermlQm+qe4gN57nb
fDqnAnSWh9Hi5E+tGBmDnfvzlR1KqpDtoiyDA1kNvOVKK3G5rK3JObPbYatyQMK6OX9Vbdj3/Z930d3aseM+tuPo5L1J3x9ff38+vr62fp1tue25f9tt7jO8/X8/vnzNAeUmqvKdYyr37D+XzM5DADAKMxFPf7iYVwrmks+5veWKR1dgQSotBrLfTrmpIVuz3ceSQC5pGSbqdsvyvH7SNPqFwRQqdsMAvWBxMdxHMeY04ZTPYL3j4+9ovenLw//rNkMadvj8RjqQMvDP88Aktoc8vL167rHiwmZA8JMvTCMyzT7q8HUtm2jUjZeuRGuvPcyktVLTkWz0D46wtryuzwMyGJ2Lw6wWGVFUvNTsxfDDT2R6rurubqb4+PxuPqcCsOpHrKNtm2FbDx1XeGvbGdK3T6vMT2Qr/P1fH5//XwBSW3qEJ70mjPFuWVh8yEV0SAipIymf7vD9HgEFZRWEPJtXR4eJtWRS9tgkRI0SxPNiwoRgYwUde2vs2mafYbwspqvVYggwcG58iX6GDpGH/7xefWp6mROlTZzFha0bv17HWnhFLzZrjENUEp6+OfPJ0ltm94eHtersDCLMd8aYLGoI6W2mUS622DW4LKhbGVB/uMuw71YIJW63zRTIwR3XSs+ltRTEstv/qdbgnVkrautDIvSYoFUDBC96/V6vebZ+1D1qEhYK0OuN/fezzto5dyUZOtqQFLqdb5ez++vP19St32qx02+OFVYWMq92Dl7DeFh1tbKi7eHJbBsjtLawrD5EtQfti2qHUIqLitgmMqayUBEpAQ70uLZpzIgm5nc7xJCtqLzSDciHzCv758/rytXU8EuTao0qblqZ+rrfpbuadY2PVBK3a/zfH5/ff181m1/DHVYv+l1qoiI+VLmIabFozDvU38d6QBI0XH3yAibZbNIrdvmty5OChB7aC6Eg5w8lAVgxNxzxMglUlBtvnf6vQdFsXiMvwRAlmjQwurUWhq4uo3eX2dXB5J241jLWgWYc7lF4xpj9Cveu8QgZY0ZCMLdIQeqbtrO1s52nQwAsgMAt0IYNnrOEREDpe7H0LDj2AqFdnd1N3O9+tq/JKW2Vqv4Cs/ZakDCfvUB9xafe/fT0q9znWP0YcBlM5yfnz8+H8extZLyWQkovK7rNRxkMxBc1O32+Xk0Du194VMYw3Vcrxavq8/kUNdaaiucuVFibMzN5l2k/MVgYQyd/b1+iKRuD3O01lohn5ctdvdqfzuSlLa1bZMxhuroA0uyrvgqgm7DIddYRiTTCzGSadCvobk7QT8/PlKb5U620+DzdWXiJ5UXaaUcx9EY7OpDLQJZKHyO81nivIZaAKcLWiM1U00BMTU3k7SXfhlMtTC6zR455gZCqbsFioqIYCjc+rjTMyKlWuG27+UEBe2vk3ZgrvvWnoSuA2+tHF+qfQh076cdBlxRNv94PB6Px741EpZcLjVHv16vPlNrdM+al0VqbZVDewrwIjGCab9eAtfS5GUptW37hmMOC5sjjTZjKYkX+WVwCnTOy5nZmQGJ6+a4sDLgU0fP6ecEyKZhiq0+jgLWQ/vrm5Eb1ePjEAgdF/raABMRd6F65wgawSTh+DiO4ziOvWTP8iY0ns+xahZf5MSVAYJ6H2oOyEzhc1xCMOY60qW0bd93JEzZyAUMsdQ9p1Q9WAanOjcFi2Rhme/vNn3d/OtKXN8ihBIzl7btx0exTqD99SXSgtvx+QPdemWMsBxT3UtNEcFTPTUIGQmJj33f9/3YBdfj5qajX+drEjOTEBWWIsKFcpCs0XMTbHp4XASoerPka23bccAtYqWWHHzKNir+MhhrYXQl9OJLVloCpewjFV1Mx/l6na/X+ZqrbFtyusdHHS8O7a+fdZvB9fjxj9BxvgTjHhVHtpER6fZwbkAupWxbTpdu9fh1x8+n1saltLrWrErBO4bkQ4YsGK6DwNFdfR3ptu3HI9aYr99bsvFOvxgA5AAAEGEIm7GgvRwkwKW5jeu6Y
vq8Jx/P2batOTKk6PnxaK/KYP311T5mcD0+f9dxPetaeX+LpAAiUtylM2OhunhdWZnf4JZYHjaUKtux72voVOG6xhVTex+5iIgJbBK40o0SWXf4ABspobJav0osScpmAJBtjUswLNVXAMkYqEAExBAK7aE9odbfX/NQDypBzKXWtu8tBeHPp10aWLaPH/211cK4FNPcfcGEk9ysOnopKHV/HFtL6EjNxe6Bayx7nV42lHZ8fiRCp9R4Pl8+wUYuhEiVB0W3SQv8sqDz2x6XEITNPlV1miq+0y/4JS+1ZlVgQwczQ2ongF59plhkbXOq2TweuWxzkxgvsEv+9z+/LoWyt4LWn382+Ocf3+cwEKCwcUkyG5AR/fUazu0Dy37s2763JhgaphcufJqvrdjTj1yJaIaYRcBqnzpisp6IiTAMbAFzc1WaEIZ/rf0AQIzk7vR4HMd+dzwwPZxgVQRbmz1Wv8SupW0mpTY1d11x9VE5Bur1RX/88XUqlL3mZm22P//8OrshA4aOjA2Jf485ZnDDWtre2rZtFdHDJr2LddAgrpta7HsrQlmWQYRDYjDeQvTJzXZ3WAt9acWwsPj6fl19mANKSEDw4/HY9237ZfDipArHjT/JwlLYx5iJzy9V1SPsWB8mn3YSxfP5fU2UvQhafwqM76/0MFLYxLCBvxItcOYamaO1JmsVtt/Kl2iBXDYHb3ktjPImOCToHN4avCm44ebEQhII5JMZ3NWfzzP9tNqYchzHse/L4CURl1htNxtj9jFzGW1hzE1zxFKqWQTYkW/nI9yGm7km87bspYD1J+rrfL1ew0AQwYbbEFi5NC0wPa97WUQ1dKpNXOtWyAKlOpDXTNPdlgoZmgdxLj1ABABKyK4riQuIIbgOdNMer+d59qm+6EFU9se+/+1IL/CqiIKN83qdfS2NKWTmFoAixcwA0I+0+KFdZ796H+bmBoVEUHvYWcfofRgyUmjoWMvKIoK2bWu1ba1IEZYilIJGo8Nqq7MFcQGSyKXQkf0TY8ZwQCEJugcMc4aF6cQUnI4wBbA5iqd2lzkhS5EidTu2fdv/65EWKRJg43x+PZ91y44/J3aMsJh7IJAfj8dxHI+j2zWe38/vK2tPqcxofXZeMlYgHuG5/j0cPDz406ny9vEoLMTMDDZDx3mekM0REwsUIKmxmFwGEe5kuYWOOPMTTOkJw7DZMRW0IlzDiZl99N5v4kWtrdZtz17h+0jjLQJoYOP8/vPnV+KJ9lZuVKKVCEDkWAY/8Irx+vnHH9+t1lZbqRxh5nDTJ0DWBvtc8+3uVhyrc3v84x6zwcSwfj6/obTS1GpC6WWtOwoAByfyW+w+V1ysZguYgtvsGEhkHGBhiIjkc8yhM4kXbdve7/1/DVpSykyD//XHvu/bvqvWRDMQFQcAIoHbYOOYrz//8//8+XgcTq3slIsKprKQMAmTh80xx1q0bFapPoLbx++pkIrgjKHjfP6EOqvVrDIyoCeazi3QkRCRRQJZSrkNDp+UHgaibNemdjK4rbWYid099n1hotLgDokUC4QcQ+ucY2aMFoYbPpvPvSAkZLW1IoQQtiB7bX9QHwjuS95JRALAbQFislgbi19zr6vy3q/rus4TgkglbiQpYqx994uin+4GTMtLwvnMpnCSTuGtRoEB4IhIv9hg73ZjOCyVB2wlwbL9TJDZoswjxHqXIQWHApeqI90cgYs+Pj8+Pz4/PqjN3My3xpgci+cqqmrKqoI++/Orkqx2mP/8/v5+nX3SEpPZ1+/9C7ywBh9EhBAGEKnLjFwWzsSg1FzcQaskySM9CDiBZBAzu09JxUuDay2t9jqWgoLWtvSdbqQoJPjqFrJE4tK2/aMP/vj4+Pj4+HzQIhLqm0N275XSqTqZkNHn9fpJcT9U/v18PV9XLt/iUtq2WoVpb+6KuSlOGJ4VWJCglJo7Q1yhlFKYKHezCYv33vug7ERORLdq/xcG11ZbsXNokLR8BlPSWlXV9CYh8
PIwsdTt6EPlI3lnD75XTb7hkTfTdc6hPBARXa8Xg3Hai/46z/N19QkOSFLbvlbCr00icGuSJLA5DB0jgJAjDDzczNTXAhnkUlqppcZ5XsLgxhimGKaqVf+rwfnHF24Sish7oc0cY4yFf8fb3jzSQ708Ph6Pj8fjwbZIWoGRMoGZvQjPsURGMWZ/gg2623pX79fVx32kt83UFH65mIDfWgNLYhQWFNAjDZ4uLCxEyCUnVd5qdvzohgtOnbPq3wxutbW61ZhTgwrcMwsImxlYattakLw5AMSlbUM96iM/By+sx73KJPT28GBiRABHnxfYOJdMV6666nNM9uTf7pMmBMLS1UBIPTrOBeQeHs6cWSGkYJtOS6gHApe6H8d+RNo7B2Au1ZKiVecsafAfAECttbq10dDcgoTKG8Oydhe/tsMCpdxHOjciqAdujyP/8BLPC1sdy/mrO0OMEGHoCj4uwbvfr6tiXeuy28YIYLRiMsDSr2YihAXyEmAqpVa0tc80y6X08P74eHxEKtINCQiHgKCqc5bb4D8BgNrW6mhbZwRAlDeZBlI65fv7uZsTt3ir7yKXujkgj5Vq7nLD+dWTYaSJzS/CjAQBjuhTB9K9ZiVL5uQQB3KpbScIZ4Xwu7PLyMxEDBAQrqbBQFL3JuZmU+fQ9TDl4/j544enf2vRtd0C5yyz6t8M7q2NNlrh5JdSJkvv3cU/f34AlqJJ4b6PtAdSmcdx7Mdx7JJOAViLw+bbw2vJikG4RXj4bfCSAFylkNS2QZhNBH/3wW5tIYgEEEwoQNK2o6ipjjn6TMA5IJe6Pz5/+z1y10aVSJllhaJF5yzzLXpAkocAvJbEkYV5WLgHlbo7Sj0ej1bIZ1/DRbShgSSVUhtsst/iHrbExRapI39vJqa6hO9uSQJY+7cAHp9HpZjX8zrP67zOUxPgiLiADpyjx8ol1VAHzdwannvrAQmpUGh/MfifP7+fZ+/TgalEBN7UijclnphTfRuIkWvbqqkaaspKIdf9o2xbK+gjEcMY4FM9SIJysf2AVIPBzBB09j5zpTJyGiw8cR2CWDy6lVoA7Z8pqkR/MTiPdK6tZCEIQKESUgRjotMrcXDmvC56wZgn2rCvnz+/n68+3k/K/aS+DUZihDAw5IpctmMfY85w88ACUvc5qbAIxbw13jAWSzi3gE/wIsJIgEyQ3WWNACCBQEJmFtE721uC1x68YnDd98oxX5bd4PPUG9mYEnJSsiclRESCrjHhufYvBK1UtpBP9HHa9/f39/O8xl15yw32gLfBgCm+TTWQ63Y8Oq89D1Q4uxIIAOAj0bQYgBEByATMGAZhVhbLaCG8L0UkWDs8mKUUZWTMVtTSC1tiWCJFCsWIsQy+Fq4u+TZlEWaLcJHMcxXs+UoPQ0AuzGX0YYNJX6/X63n2iSBt2/atBNyczpswfZ8xVkcp2/HBCK4zPIgzA1dX1dwrlf/CkhDJeO5hWAKRAzF3l89+GQms3Iy5SCmWJQmTJ7XCPBcusxAikM8JfzE4j/8CHFaryMi11cUjsvk6z96nGqb0T60Q7gMi9Eq5hyHI9Xh8PNrScPZ4G2wO5uYmFsi1HR+UXe4IlsyX+rh673P6HV2FOFUIIHFwYUBsAUTrSHctkEvmmVmm6HBh4SIsNz8jZ+ZFCpia2TT9u8EAUGttc1Zz4IJSt93G9NA5xis3bBjl0GFr9z6LsThxo4G0/fO3H5tqwjHeBodqmJvOog5ctuMDzOaF4EElq+fX60k+PPWlAyKqFGHiwm6ZZxhySnEghNnslwYykNTKInMWrb42CZQ1hFRbgmLVr36NOfsa6JyXBi6CSGuzTTXnvG2PiTF99vM6+3WNqcoLwr5183n1q485dM45pgHX4/P3f+yJAh3zl8EYBq5jVA3ksh0fPnPzclDZHsdxHF8/yefps6dsboRXpEApouHhqlq4lFxvlM/3ZdmeKptNKVNtRln7hm1xx63V2lqrbT6/Q2O+vq/rNjgFK
6EtpR+obQ2tXdG1v85kJath3FsUh4/r+XwOXcy0XJv7+/945CBQur+P9EIGO8uiHS9eBmaVenw8QntlzKOnt0jRDVAKN8v1PAuL6GY6rSxF5sw7gkJKbVttTftgIUQtdemDDwaf1+u7X/26eh+aQ+VAYlVZosWZhuTILXdrBNwLN2trs69tGAs+56sz8flxLVqNw5othdQ255zz8Y8fj60Q2Nq90XxBjhJSQLmCePX3MBR8Zs2AjDU3hSdwxt88WUQMm7332X3f9kAqkWPiful9mftzfeZqg9/vCN4kvS0nOpeMPjIv81U+lm1rrdVamjoQl9YXUGQuqC4iixXIbsgyeEGs9Pj9x8deOcwhnWtJTNOpFqt3U0qpdWtC6GqKWfmgQCnCix4QN0/2JgnMfp7X6UMDuNgSlTzPmbX2nOP7+5k0JTUN4IhVBFMe+Va3wrk1aabUj5TMw5jLvm2ttlJTByt1qEcfCKkESEQkHgiYO8T3fNiytNs/Px5b4VwsIlLq7eE8P2lwrbXVjQJCAwIWAYdqkVXUxD00zKcFMyF/PU0DqdREGo1+nV1VzWbT8Xx+P7+/n69kkpHcSwYpWxO11coUOih05OaikpgSFrk9fDft+nVeFxN4yogRErsEIsn7SL+p+PVxHHslsKVNqXavrciklTjK8jDcOklFSghzqWVtDM2Z8DtlxgTCf39/qSNJU3tvFe22pKHfRzrfIs4aRpil5nrqWopg2ASzmczXAoisKiq5R64WIC5tjtFf7cmIrqkEQkScvtK3h2/tISxb21rhZbCUdaT/4mHIANFazBE65/BWAxml1LWaeMEc7ocUcdWYX38okrR9OtwePhdO5teRXnU4LRq1lLUjJpXoNWyk/DUJIrGxqN0eriTSVNWu1Dqe4z7SSAzIlhsX08MLakzZZySwyNfcPbE1GvY2OIPWZo5u47pUPYkw5VfQei/IXHW+zX5+//xzsrTtUIs1F3+dlmydXx7OvCsVHxKpuXpL2dMyXIMMZEQyYxNfBpfCnjL8V44dR5ebAQhIyWF4e5gLMxfOhbZIsY60J58SHEJ9IStKqaXWrc2JoeN8zSSBcMntk8vD/vegNfrr+fOPmdogt4fP86lz6mxvg7+ftXpl+P+2dy3JDcMglI+M09z/oJ2p3EiWBV0ATtqmF8iUhVYeC88g89V7tCShpgS1YAAyq3k2CEhMPG1O1XJJk46kszGZHqM576RbzN1rlGuG6EWWQmoJIo7ExXACAJrZdJNmJnfUsmJHPXrbdixyUWQR9sThnN355pZu9eN9l7dr64dajHp8blFfG73WWmvdtmnAilTEIU5lXdgrgTCH2RxzYES0FNdLrcTlcvH6NVBD1GPva3glIjLgAGf52VtCVQPFU9VzajBbRiesCGWskkRTD0//EkcfTNKDAODR55Iv8zz2LLCAxn7oPHauBQA4cmLoFOfff873pgNCLgD0h44vK/8f/OryBZ6pZjRd/4oJAAAAAElFTkSuQmCC " /> </div> </div> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>Be warned, it will succeed sometimes, just not consistently. For comparison, the first attack succeeds with close to 100% (we couldn't make it fail). 
Actually, since we have 10 classes, if we suppose the out-of-distribution predictions are uniformly random, the success rate should be close to 10%: that is the chance that our initial random image lands on a point where the two models agree on the same digit.</p> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="k">def</span> <span class="nf">loss</span><span class="p">(</span><span class="n">outputs</span><span class="p">):</span> <span class="n">entropy1</span> <span class="o">=</span> <span class="n">Categorical</span><span class="p">(</span><span class="n">logits</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">entropy</span><span class="p">()</span> <span class="n">entropy2</span> <span class="o">=</span> <span class="n">Categorical</span><span class="p">(</span><span class="n">logits</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">entropy</span><span class="p">()</span> <span class="n">kl_loss1</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">kl_div</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">exp</span><span class="p">(),</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;batchmean&#39;</span><span class="p">)</span> <span class="n">kl_loss2</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span
class="n">kl_div</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">exp</span><span class="p">(),</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;batchmean&#39;</span><span class="p">)</span> <span class="n">distance</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">mse_loss</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">entropy1</span> <span class="o">+</span> <span class="n">entropy2</span> <span class="o">+</span> <span class="n">kl_loss1</span> <span class="o">+</span> <span class="n">kl_loss2</span> <span class="o">+</span> <span class="n">distance</span> <span class="k">return</span> <span class="n">loss</span> <span class="k">def</span> <span class="nf">attack_rate</span><span class="p">():</span> <span class="n">attacks</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span> <span class="n">success</span> <span class="o">=</span> <span class="n">attack</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">n</span><span class="o">=</span><span class="mi">200</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span 
class="mf">0.1</span><span class="p">)</span> <span class="k">if</span> <span class="n">success</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;F&quot;</span><span class="p">,</span> <span class="n">end</span><span class="o">=</span><span class="s1">&#39;&#39;</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;.&quot;</span><span class="p">,</span> <span class="n">end</span><span class="o">=</span><span class="s1">&#39;&#39;</span><span class="p">)</span> <span class="n">attacks</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">success</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;&#39;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Attack success rate </span><span class="si">{</span><span class="nb">sum</span><span class="p">(</span><span class="n">attacks</span><span class="p">)</span><span class="o">/</span><span class="nb">len</span><span class="p">(</span><span class="n">attacks</span><span class="p">)</span> <span class="o">*</span> <span class="mi">100</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2">%&quot;</span><span class="p">)</span> <span class="c1"># attack_rate()</span> </pre></div> </div> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>The actual attack success rate seems to stagnate at around 0% (30% if we drop the confidence &gt; 80% requirement) across various learning rates and numbers of attack steps.
There probably are better strategies to attack this, but the main point is that the attack became <strong>harder</strong>.</p> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h3 id="Experiment-2">Experiment 2<a class="anchor-link" href="#Experiment-2"> </a></h3><p>Now let's test this on common OOD detection for classic datasets. We will also run OOD detection on the train dataset, just to check that we don't <em>exclude</em> too much of the original dataset. The datasets used will be MNIST and FashionMNIST.</p> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <details class="description"> <summary class="btn btn-sm" data-open="Hide Code" data-close="Show Code"></summary> <p><div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="c1">#collapse</span> <span class="kn">from</span> <span class="nn">torchvision.datasets</span> <span class="kn">import</span> <span class="n">MNIST</span><span class="p">,</span> <span class="n">Omniglot</span><span class="p">,</span> <span class="n">FashionMNIST</span> <span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">transforms</span> <span class="kn">import</span> <span class="nn">os</span> <span class="k">def</span> <span class="nf">dataset_multi</span><span class="p">(</span><span class="n">dataset_cls</span><span class="p">,</span> <span class="n">filename</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">transform</span><span class="p">):</span> <span class="c1"># Training settings</span> <span class="n">parser</span> <span class="o">=</span> <span class="n">argparse</span><span class="o">.</span><span class="n">ArgumentParser</span><span class="p">(</span><span class="n">description</span><span class="o">=</span><span
class="s1">&#39;PyTorch MNIST Example&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--batch-size&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;input batch size for training (default: 64)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--test-batch-size&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">1000</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;input batch size for testing (default: 1000)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--epochs&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">14</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;number of epochs to 
train (default: 14)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--lr&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">float</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;LR&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;learning rate (default: 1e-2)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--gamma&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">float</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mf">0.7</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;M&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;Learning rate step gamma (default: 0.7)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--no-cuda&#39;</span><span class="p">,</span> <span class="n">action</span><span class="o">=</span><span class="s1">&#39;store_true&#39;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;disables CUDA training&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span
class="p">(</span><span class="s1">&#39;--seed&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;S&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;random seed (default: 1)&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--log-interval&#39;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;how many batches to wait before logging training status&#39;</span><span class="p">)</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">&#39;--save-model&#39;</span><span class="p">,</span> <span class="n">action</span><span class="o">=</span><span class="s1">&#39;store_true&#39;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s1">&#39;For Saving the current Model&#39;</span><span class="p">)</span> <span class="n">args</span> <span class="o">=</span> <span class="n">parser</span><span class="o">.</span><span class="n">parse_args</span><span class="p">()</span> <span class="n">use_cuda</span> <span class="o">=</span> <span 
class="ow">not</span> <span class="n">args</span><span class="o">.</span><span class="n">no_cuda</span> <span class="ow">and</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">args</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cuda&quot;</span> <span class="k">if</span> <span class="n">use_cuda</span> <span class="k">else</span> <span class="s2">&quot;cpu&quot;</span><span class="p">)</span> <span class="n">kwargs</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;num_workers&#39;</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="s1">&#39;pin_memory&#39;</span><span class="p">:</span> <span class="kc">True</span><span class="p">}</span> <span class="k">if</span> <span class="n">use_cuda</span> <span class="k">else</span> <span class="p">{}</span> <span class="n">train_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span> <span class="n">dataset_cls</span><span class="p">(</span><span class="s1">&#39;../data&#39;</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transform</span><span 
class="p">),</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="n">test_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span> <span class="n">dataset_cls</span><span class="p">(</span><span class="s1">&#39;../data&#39;</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transform</span><span class="p">),</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">test_batch_size</span><span class="p">)</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">lr</span><span class="p">)</span> <span class="n">scheduler</span> <span class="o">=</span> <span class="n">optim</span><span class="o">.</span><span class="n">lr_scheduler</span><span class="o">.</span><span class="n">CyclicLR</span><span class="p">(</span> <span class="n">optimizer</span><span class="p">,</span> <span 
class="n">base_lr</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">max_lr</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">lr</span><span class="p">,</span> <span class="n">cycle_momentum</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">step_size_up</span><span class="o">=</span><span class="mi">200</span> <span class="p">)</span> <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">args</span><span class="o">.</span><span class="n">epochs</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span> <span class="n">train_multi</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">)</span> <span class="n">test_multi</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">)</span> <span class="n">scheduler</span><span class="o">.</span><span class="n">step</span><span class="p">()</span> <span class="k">if</span> <span class="n">args</span><span class="o">.</span><span class="n">save_model</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">filename</span><span class="p">)</span> 
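# Aside: a minimal, self-contained sketch of the symmetric KL anomaly
# score that the evaluation code further down relies on. The helper name
# `kl_anomaly_score` is hypothetical (not from the original notebook);
# both model heads are assumed to return log-probabilities, as they do
# when ending in log_softmax.
import torch
import torch.nn.functional as F

def kl_anomaly_score(log_p, log_q):
    # F.kl_div expects log-probabilities as input and probabilities as
    # target, returning target * (log target - input) element-wise.
    forward = F.kl_div(log_p, log_q.exp(), reduction='none')
    backward = F.kl_div(log_q, log_p.exp(), reduction='none')
    # Keep the larger directed term per class, then sum over classes:
    # a large score means the two heads disagree, hinting the input is
    # out of distribution.
    return torch.max(forward, backward).sum(dim=-1)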
<span class="k">def</span> <span class="nf">run_datasets</span><span class="p">(</span><span class="n">create_model</span><span class="p">,</span> <span class="n">suffix</span><span class="p">):</span> <span class="n">datasets</span> <span class="o">=</span> <span class="p">[</span><span class="n">MNIST</span><span class="p">,</span> <span class="n">FashionMNIST</span><span class="p">]</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cuda&quot;</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="s2">&quot;cpu&quot;</span><span class="p">)</span> <span class="k">for</span> <span class="n">dataset_cls</span> <span class="ow">in</span> <span class="n">datasets</span><span class="p">:</span> <span class="n">filename</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">dataset_cls</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}{</span><span class="n">suffix</span><span class="si">}</span><span class="s1">.pt&#39;</span> <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">filename</span><span class="p">):</span> <span class="k">continue</span> <span class="n">model</span> <span class="o">=</span> <span class="n">create_model</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="n">transform</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span 
class="n">Compose</span><span class="p">([</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,))</span> <span class="p">])</span> <span class="n">dataset_multi</span><span class="p">(</span><span class="n">dataset_cls</span><span class="p">,</span> <span class="n">filename</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">transform</span><span class="p">)</span> <span class="k">def</span> <span class="nf">create_model</span><span class="p">():</span> <span class="n">model1</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span> <span class="n">model2</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span> <span class="n">model</span> <span class="o">=</span> <span class="n">MultiNet</span><span class="p">(</span><span class="n">model1</span><span class="p">,</span> <span class="n">model2</span><span class="p">)</span> <span class="k">return</span> <span class="n">model</span> <span class="c1"># run_datasets(create_model, suffix=&#39;&#39;)</span> </pre></div> </div> </div> </div> </p> </details> </div> <div class="cell border-box-sizing code_cell rendered"> <div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">roc_auc_score</span> </pre></div> </div> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span 
class="k">def</span> <span class="nf">test_datasets</span><span class="p">(</span><span class="n">model_arch</span><span class="p">,</span> <span class="n">suffix</span><span class="p">):</span> <span class="n">datasets</span> <span class="o">=</span> <span class="p">[</span><span class="n">MNIST</span><span class="p">,</span> <span class="n">FashionMNIST</span><span class="p">]</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">500</span> <span class="n">device</span> <span class="o">=</span> <span class="s1">&#39;cuda&#39;</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="s1">&#39;cpu&#39;</span> <span class="k">for</span> <span class="n">dataset_cls</span> <span class="ow">in</span> <span class="n">datasets</span><span class="p">:</span> <span class="n">filename</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">dataset_cls</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}{</span><span class="n">suffix</span><span class="si">}</span><span class="s1">.pt&#39;</span> <span class="n">model_arch</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">filename</span><span class="p">))</span> <span class="n">model</span> <span class="o">=</span> <span class="n">model_arch</span> <span class="n">test_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span> <span class="n">dataset_cls</span><span 
class="p">(</span><span class="s1">&#39;../data&#39;</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,))</span> <span class="p">])),</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span> <span class="n">ref_kl_loss</span> <span class="o">=</span> <span class="n">kl</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Ref loss&quot;</span><span class="p">,</span> <span class="n">ref_kl_loss</span><span class="p">)</span> <span class="n">all_labels</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">all_scores</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">dataset_cls2</span> <span class="ow">in</span> <span class="n">datasets</span><span class="p">:</span> <span class="n">test_loader2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span 
class="n">DataLoader</span><span class="p">(</span> <span class="n">dataset_cls2</span><span class="p">(</span><span class="s1">&#39;../data&#39;</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,))</span> <span class="p">])),</span><span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">OOD</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="ow">in</span> <span class="n">test_loader2</span><span class="p">:</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">))</span> <span class="n">kl_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">F</span><span class="o">.</span><span class="n">kl_div</span><span class="p">(</span><span 
class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">exp</span><span class="p">(),</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;none&#39;</span><span class="p">),</span> <span class="n">F</span><span class="o">.</span><span class="n">kl_div</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">exp</span><span class="p">(),</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;none&#39;</span><span class="p">))</span> <span class="n">kl_loss</span> <span class="o">=</span> <span class="n">kl_loss</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="n">similar</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="n">normed</span> <span class="o">=</span> <span class="n">kl_loss</span> <span class="o">/</span> <span class="n">ref_kl_loss</span> <span class="n">kl_anomaly</span> <span 
class="o">=</span> <span class="n">normed</span> <span class="o">&gt;</span> <span class="mi">10</span> <span class="n">non_concordant</span> <span class="o">=</span> <span class="n">similar</span> <span class="o">==</span> <span class="kc">False</span> <span class="n">out_of_distrib</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">kl_anomaly</span> <span class="o">|</span> <span class="n">non_concordant</span><span class="p">)</span> <span class="n">N</span> <span class="o">=</span> <span class="n">normed</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="n">boolean</span> <span class="o">=</span> <span class="n">dataset_cls2</span> <span class="o">!=</span> <span class="n">dataset_cls</span> <span class="n">all_labels</span><span class="o">.</span><span class="n">extend</span><span class="p">([</span><span class="n">boolean</span><span class="p">]</span> <span class="o">*</span> <span class="n">N</span><span class="p">)</span> <span class="n">all_scores</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">normed</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span> <span class="n">OOD</span> <span class="o">+=</span> <span class="n">out_of_distrib</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Trained on </span><span class="si">{</span><span class="n">dataset_cls</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2"> we detected on </span><span class="si">{</span><span class="n">dataset_cls2</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">OOD</span><span class="si">}</span><span class="s2">/</span><span class="si">{</span><span 
class="nb">len</span><span class="p">(</span><span class="n">test_loader2</span><span class="o">.</span><span class="n">dataset</span><span class="p">)</span><span class="si">}</span><span class="s2"> (</span><span class="si">{</span><span class="nb">float</span><span class="p">(</span><span class="n">OOD</span><span class="p">)</span><span class="o">/</span><span class="nb">len</span><span class="p">(</span><span class="n">test_loader2</span><span class="o">.</span><span class="n">dataset</span><span class="p">)</span> <span class="o">*</span> <span class="mi">100</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2">%) out of distribution&quot;</span><span class="p">)</span> <span class="n">auc</span> <span class="o">=</span> <span class="n">roc_auc_score</span><span class="p">(</span><span class="n">all_labels</span><span class="p">,</span> <span class="n">all_scores</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;AUC for </span><span class="si">{</span><span class="n">dataset_cls</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2"> : </span><span class="si">{</span><span class="n">auc</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="k">def</span> <span class="nf">exp_2</span><span class="p">():</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda&#39;</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="s1">&#39;cpu&#39;</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span 
class="n">MultiNet</span><span class="p">(</span><span class="n">Net</span><span class="p">(),</span> <span class="n">Net</span><span class="p">())</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="n">test_datasets</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">suffix</span><span class="o">=</span><span class="s1">&#39;&#39;</span><span class="p">)</span> <span class="c1"># exp_2()</span> </pre></div> </div> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>So we can see that, with no tuning whatsoever, we achieve a decent out-of-distribution detector. We seem to achieve a much better AUROC on MNIST, probably because the in-distribution learning is much better there (99% test accuracy vs 92% for FashionMNIST). So the false positives for FashionMNIST probably come from this harder-to-learn in-distribution data. Some fine-tuning needs to be done to get better results.
We also have to keep in mind that the models used here are quite small (2M parameters but only 2 convolution layers), so the validity of the lottery ticket hypothesis for such a network might be questioned.</p> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h3 id="Experiment-2-bis">Experiment 2 bis<a class="anchor-link" href="#Experiment-2-bis"> </a></h3><p>Same experiment, but with fine-tuned, larger networks on the same datasets.</p> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="kn">from</span> <span class="nn">torchvision.models.resnet</span> <span class="kn">import</span> <span class="n">ResNet</span><span class="p">,</span> <span class="n">BasicBlock</span> <span class="k">class</span> <span class="nc">MnistResNet</span><span class="p">(</span><span class="n">ResNet</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="nb">super</span><span class="p">(</span><span class="n">MnistResNet</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">BasicBlock</span><span class="p">,</span> <span class="p">[</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span> <span class="n">num_classes</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span
class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="mi">7</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="k">return</span> <span class="n">F</span><span class="o">.</span><span class="n">log_softmax</span><span class="p">(</span><span class="nb">super</span><span class="p">(</span><span class="n">MnistResNet</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="k">def</span> <span class="nf">run_datasets_res</span><span class="p">():</span> <span class="n">datasets</span> <span class="o">=</span> <span class="p">[</span><span class="n">MNIST</span><span class="p">,</span> <span class="n">FashionMNIST</span><span class="p">]</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span 
class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cuda&quot;</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="s2">&quot;cpu&quot;</span><span class="p">)</span> <span class="k">for</span> <span class="n">dataset_cls</span> <span class="ow">in</span> <span class="n">datasets</span><span class="p">:</span> <span class="n">filename</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">dataset_cls</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s1">_resnet.pt&#39;</span> <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">filename</span><span class="p">):</span> <span class="k">continue</span> <span class="n">multi_res</span> <span class="o">=</span> <span class="n">MultiNet</span><span class="p">(</span><span class="n">MnistResNet</span><span class="p">(),</span> <span class="n">MnistResNet</span><span class="p">())</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="n">transform</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Resize</span><span class="p">((</span><span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">)),</span><span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span> <span class="n">transforms</span><span 
class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,))])</span> <span class="n">dataset_multi</span><span class="p">(</span><span class="n">dataset_cls</span><span class="p">,</span> <span class="n">filename</span><span class="p">,</span> <span class="n">multi_res</span><span class="p">,</span> <span class="n">transform</span><span class="p">)</span> <span class="k">def</span> <span class="nf">test_datasets_bis</span><span class="p">(</span><span class="n">model_arch</span><span class="p">):</span> <span class="n">datasets</span> <span class="o">=</span> <span class="p">[</span><span class="n">MNIST</span><span class="p">,</span> <span class="n">FashionMNIST</span><span class="p">]</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">10</span> <span class="n">device</span> <span class="o">=</span> <span class="s1">&#39;cuda&#39;</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="s1">&#39;cpu&#39;</span> <span class="k">for</span> <span class="n">dataset_cls</span> <span class="ow">in</span> <span class="n">datasets</span><span class="p">:</span> <span class="n">filename</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">dataset_cls</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s1">_resnet.pt&#39;</span> <span class="n">model_arch</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">filename</span><span class="p">))</span> 
<span class="n">model</span> <span class="o">=</span> <span class="n">model_arch</span> <span class="n">test_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span> <span class="n">dataset_cls</span><span class="p">(</span><span class="s1">&#39;../data&#39;</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Resize</span><span class="p">((</span><span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">)),</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,))</span> <span class="p">])),</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span> <span class="n">ref_kl_loss</span> <span class="o">=</span> <span class="n">kl</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Ref loss&quot;</span><span class="p">,</span> <span 
class="n">ref_kl_loss</span><span class="p">)</span> <span class="n">all_labels</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">all_scores</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">dataset_cls2</span> <span class="ow">in</span> <span class="n">datasets</span><span class="p">:</span> <span class="n">test_loader2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span> <span class="n">dataset_cls2</span><span class="p">(</span><span class="s1">&#39;../data&#39;</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Resize</span><span class="p">((</span><span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">)),</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,))</span> <span class="p">])),</span><span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span 
class="n">OOD</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="ow">in</span> <span class="n">test_loader2</span><span class="p">:</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">))</span> <span class="n">kl_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">F</span><span class="o">.</span><span class="n">kl_div</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">exp</span><span class="p">(),</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;none&#39;</span><span class="p">),</span> <span class="n">F</span><span class="o">.</span><span class="n">kl_div</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">exp</span><span class="p">(),</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;none&#39;</span><span class="p">))</span> <span class="n">kl_loss</span> <span class="o">=</span> <span class="n">kl_loss</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="n">similar</span> <span 
class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="n">normed</span> <span class="o">=</span> <span class="n">kl_loss</span> <span class="o">/</span> <span class="n">ref_kl_loss</span> <span class="n">kl_anomaly</span> <span class="o">=</span> <span class="n">normed</span> <span class="o">&gt;</span> <span class="mi">10</span> <span class="n">non_concordant</span> <span class="o">=</span> <span class="n">similar</span> <span class="o">==</span> <span class="kc">False</span> <span class="n">out_of_distrib</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">kl_anomaly</span> <span class="o">|</span> <span class="n">non_concordant</span><span class="p">)</span> <span class="n">N</span> <span class="o">=</span> <span class="n">normed</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="n">boolean</span> <span class="o">=</span> <span class="n">dataset_cls2</span> <span class="o">!=</span> <span class="n">dataset_cls</span> <span class="n">all_labels</span><span class="o">.</span><span class="n">extend</span><span class="p">([</span><span class="n">boolean</span><span class="p">]</span> <span class="o">*</span> <span class="n">N</span><span class="p">)</span> <span class="n">all_scores</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span 
class="n">normed</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span> <span class="n">OOD</span> <span class="o">+=</span> <span class="n">out_of_distrib</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Trained on </span><span class="si">{</span><span class="n">dataset_cls</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2"> we detected on </span><span class="si">{</span><span class="n">dataset_cls2</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">OOD</span><span class="si">}</span><span class="s2">/</span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">test_loader2</span><span class="o">.</span><span class="n">dataset</span><span class="p">)</span><span class="si">}</span><span class="s2"> (</span><span class="si">{</span><span class="nb">float</span><span class="p">(</span><span class="n">OOD</span><span class="p">)</span><span class="o">/</span><span class="nb">len</span><span class="p">(</span><span class="n">test_loader2</span><span class="o">.</span><span class="n">dataset</span><span class="p">)</span> <span class="o">*</span> <span class="mi">100</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2">%) out of distribution&quot;</span><span class="p">)</span> <span class="n">auc</span> <span class="o">=</span> <span class="n">roc_auc_score</span><span class="p">(</span><span class="n">all_labels</span><span class="p">,</span> <span class="n">all_scores</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;AUC for </span><span class="si">{</span><span class="n">dataset_cls</span><span class="o">.</span><span class="vm">__name__</span><span 
class="si">}</span><span class="s2"> : </span><span class="si">{</span><span class="n">auc</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="k">def</span> <span class="nf">exp_2_bis</span><span class="p">():</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda&#39;</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="s1">&#39;cpu&#39;</span><span class="p">)</span> <span class="n">multi_res</span> <span class="o">=</span> <span class="n">MultiNet</span><span class="p">(</span><span class="n">MnistResNet</span><span class="p">(),</span> <span class="n">MnistResNet</span><span class="p">())</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="n">test_datasets_bis</span><span class="p">(</span><span class="n">multi_res</span><span class="p">)</span> <span class="c1"># run_datasets_res()</span> <span class="c1"># exp_2_bis()</span> </pre></div> </div> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h3 id="Experiment-3">Experiment 3<a class="anchor-link" href="#Experiment-3"> </a></h3><p>Check that two identical networks (same initalization) actually don't work. It's just a sanity check. 
We should always obtain kl_div = 0 no matter where we are in the input space.</p> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="k">def</span> <span class="nf">create_same_model</span><span class="p">():</span> <span class="n">model1</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span> <span class="n">model</span> <span class="o">=</span> <span class="n">MultiNet</span><span class="p">(</span><span class="n">model1</span><span class="p">,</span> <span class="n">model1</span><span class="p">)</span> <span class="k">return</span> <span class="n">model</span> <span class="k">def</span> <span class="nf">exp_3</span><span class="p">():</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda&#39;</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="s1">&#39;cpu&#39;</span><span class="p">)</span> <span class="n">run_datasets</span><span class="p">(</span><span class="n">create_same_model</span><span class="p">,</span> <span class="n">suffix</span><span class="o">=</span><span class="s1">&#39;_exp3&#39;</span><span class="p">)</span> <span class="n">test_datasets</span><span class="p">(</span><span class="n">create_same_model</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">suffix</span><span class="o">=</span><span class="s1">&#39;_exp3&#39;</span><span class="p">)</span> <span class="c1"># exp_3()</span> </pre></div> </div> </div>
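The identical-initialization case can be checked without any training: the KL divergence of a distribution with itself is zero. Here is a minimal sketch of that sanity check, assuming log-softmax outputs as in the models above (the random logits are only a stand-in for a shared forward pass):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Two "models" sharing the same parameters produce identical log-probabilities,
# so their pairwise KL divergence is 0 everywhere in input space: the
# disagreement signal used for out-of-distribution detection vanishes.
logits = torch.randn(5, 10)            # stand-in for any shared forward pass
log_p = F.log_softmax(logits, dim=-1)

# F.kl_div expects (input=log-probs, target=probs), matching the usage above
kl = F.kl_div(log_p, log_p.exp(), reduction='none').sum(dim=-1)
print(kl)  # all zeros (up to floating point)
```

This is exactly why two side models must be initialized independently for the method to have any chance of working.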
</div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h3 id="Experiment-4">Experiment 4<a class="anchor-link" href="#Experiment-4"> </a></h3><p>Run this method with 2, 3, 4, and so on models. We should get exponential improved accuracy, if the random behavious for out-of-distribution for models is correct.</p> </div> </div> </div> <div class="cell border-box-sizing code_cell rendered"> <div class="input"> <div class="inner_cell"> <div class="input_area"> <div class=" highlight hl-ipython3"><pre><span></span><span class="k">def</span> <span class="nf">create_n_model</span><span class="p">(</span><span class="n">n</span><span class="p">):</span> <span class="n">models</span> <span class="o">=</span> <span class="p">[</span><span class="n">Net</span><span class="p">()</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n</span><span class="p">)]</span> <span class="n">model</span> <span class="o">=</span> <span class="n">MultiNet</span><span class="p">(</span><span class="o">*</span><span class="n">models</span><span class="p">)</span> <span class="k">return</span> <span class="n">model</span> <span class="k">def</span> <span class="nf">test_datasets_4</span><span class="p">(</span><span class="n">model_arch</span><span class="p">,</span> <span class="n">suffix</span><span class="p">):</span> <span class="n">datasets</span> <span class="o">=</span> <span class="p">[</span><span class="n">MNIST</span><span class="p">,</span> <span class="n">FashionMNIST</span><span class="p">]</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">100</span> <span class="n">device</span> <span class="o">=</span> <span class="s1">&#39;cuda&#39;</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span 
class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="s1">&#39;cpu&#39;</span> <span class="k">for</span> <span class="n">dataset_cls</span> <span class="ow">in</span> <span class="n">datasets</span><span class="p">:</span> <span class="n">filename</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">dataset_cls</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}{</span><span class="n">suffix</span><span class="si">}</span><span class="s1">.pt&#39;</span> <span class="n">model_arch</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">filename</span><span class="p">))</span> <span class="n">model</span> <span class="o">=</span> <span class="n">model_arch</span> <span class="n">test_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span> <span class="n">dataset_cls</span><span class="p">(</span><span class="s1">&#39;../data&#39;</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span> <span class="n">transforms</span><span class="o">.</span><span 
class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,))</span> <span class="p">])),</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span> <span class="n">ref_kl_loss</span> <span class="o">=</span> <span class="n">kl</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Ref loss&quot;</span><span class="p">,</span> <span class="n">ref_kl_loss</span><span class="p">)</span> <span class="n">all_labels</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">all_scores</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">dataset_cls2</span> <span class="ow">in</span> <span class="n">datasets</span><span class="p">:</span> <span class="n">test_loader2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span> <span class="n">dataset_cls2</span><span class="p">(</span><span class="s1">&#39;../data&#39;</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span 
class="p">(),</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,))</span> <span class="p">])),</span><span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">OOD</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="ow">in</span> <span class="n">test_loader2</span><span class="p">:</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">))</span> <span class="n">kl_losses</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">outputs</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">outputs</span><span class="p">)):</span> <span class="n">kl_losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">F</span><span class="o">.</span><span class="n">kl_div</span><span 
class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">outputs</span><span class="p">[</span><span class="n">j</span><span class="p">]</span><span class="o">.</span><span class="n">exp</span><span class="p">(),</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;none&#39;</span><span class="p">))</span> <span class="n">kl_losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">F</span><span class="o">.</span><span class="n">kl_div</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="n">j</span><span class="p">],</span> <span class="n">outputs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">exp</span><span class="p">(),</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;none&#39;</span><span class="p">))</span> <span class="n">kl_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">kl_losses</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">values</span> <span class="n">kl_loss</span> <span class="o">=</span> <span class="n">kl_loss</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="n">similar</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span 
class="p">]</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="n">normed</span> <span class="o">=</span> <span class="n">kl_loss</span> <span class="o">/</span> <span class="n">ref_kl_loss</span> <span class="n">kl_anomaly</span> <span class="o">=</span> <span class="n">normed</span> <span class="o">&gt;</span> <span class="mi">10</span> <span class="n">non_concordant</span> <span class="o">=</span> <span class="n">similar</span> <span class="o">==</span> <span class="kc">False</span> <span class="n">out_of_distrib</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">kl_anomaly</span> <span class="o">|</span> <span class="n">non_concordant</span><span class="p">)</span> <span class="n">N</span> <span class="o">=</span> <span class="n">normed</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="n">boolean</span> <span class="o">=</span> <span class="n">dataset_cls2</span> <span class="o">!=</span> <span class="n">dataset_cls</span> <span class="n">all_labels</span><span class="o">.</span><span class="n">extend</span><span class="p">([</span><span class="n">boolean</span><span class="p">]</span> <span class="o">*</span> <span class="n">N</span><span class="p">)</span> <span class="n">all_scores</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">normed</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span> 
<span class="n">OOD</span> <span class="o">+=</span> <span class="n">out_of_distrib</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Trained on </span><span class="si">{</span><span class="n">dataset_cls</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2"> we detected on </span><span class="si">{</span><span class="n">dataset_cls2</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">OOD</span><span class="si">}</span><span class="s2">/</span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">test_loader2</span><span class="o">.</span><span class="n">dataset</span><span class="p">)</span><span class="si">}</span><span class="s2"> (</span><span class="si">{</span><span class="nb">float</span><span class="p">(</span><span class="n">OOD</span><span class="p">)</span><span class="o">/</span><span class="nb">len</span><span class="p">(</span><span class="n">test_loader2</span><span class="o">.</span><span class="n">dataset</span><span class="p">)</span> <span class="o">*</span> <span class="mi">100</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2">%) out of distribution&quot;</span><span class="p">)</span> <span class="n">auc</span> <span class="o">=</span> <span class="n">roc_auc_score</span><span class="p">(</span><span class="n">all_labels</span><span class="p">,</span> <span class="n">all_scores</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;AUC for </span><span class="si">{</span><span class="n">dataset_cls</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2"> : </span><span class="si">{</span><span class="n">auc</span><span 
class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="k">def</span> <span class="nf">exp_4</span><span class="p">():</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda&#39;</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="s1">&#39;cpu&#39;</span><span class="p">)</span> <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">8</span><span class="p">]:</span> <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;=&#39;</span> <span class="o">*</span> <span class="mi">20</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;N = </span><span class="si">{</span><span class="n">n</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="n">run_datasets</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">create_n_model</span><span class="p">(</span><span class="n">n</span><span class="p">),</span> <span class="n">suffix</span><span class="o">=</span><span class="sa">f</span><span class="s1">&#39;_exp4_</span><span class="si">{</span><span class="n">n</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span> <span class="n">test_datasets_4</span><span class="p">(</span><span class="n">create_n_model</span><span class="p">(</span><span class="n">n</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span 
class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">suffix</span><span class="o">=</span><span class="sa">f</span><span class="s1">&#39;_exp4_</span><span class="si">{</span><span class="n">n</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span> <span class="c1"># exp_4()</span> </pre></div> </div> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>Seems not to be working too great, we ARE improving AUC. Not by a strong margin, it is probably just that we are having a better approximator of our metric by ensembling.</p> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h3 id="Experiment-5">Experiment 5<a class="anchor-link" href="#Experiment-5"> </a></h3><p>Test on a larger output space, like CIFAR-100 and SVHN, to check that part of the limits are actually due to small number of output classes for MNIST/FashionMNIST Other idea is to test on Transformers. Early experiment seems to show that we can use that idea to detect different language within text with just the kl_div used as a distance.</p> <p>Found French book within english books dataset, AND english paragraphs <em>within</em> this french book. 
This experiment still needs some work to be cleaned up.</p> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <ul> <li>Show that a small network trained on a single English book makes it possible to detect different languages or different patterns of writing (Old English, Irish, French, or even dictionaries)</li> <li>The detection is super fine-grained, capable of detecting English within a French book.</li> </ul> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>For brevity, we won't include training code. We just trained a simple transformer (6 layers deep) on an English text and checked our metric on some other texts.</p> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p><img src="/images/copied_from_nb/images/self-kl-train-eng.png" alt="title" /> <img src="/images/copied_from_nb/images/self-kl-test-eng.png" alt="title" /> <img src="/images/copied_from_nb/images/self-kl-test-fr.png" alt="title" /></p> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h3 id="Experiment-6">Experiment 6<a class="anchor-link" href="#Experiment-6"> </a></h3><p>Need to test with various training schemes, regularization schemes (dropout, batchnorm, l2 penalization) and so on. We should find that the smoother our models behave in-distribution, the better this method should work.
Hopefully test accuracy <em>should</em> be a good smoothness proxy.</p> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h2 id="Limits">Limits<a class="anchor-link" href="#Limits"> </a></h2> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>The pros of this method are that:</p> <ul> <li>It's super simple to implement, and only costs a constant factor in training time.</li> <li>You could also extend this to 3, 4 side models, and it <em>should</em> improve robustness exponentially if the random factors are correct. If we keep this number small, it will still be a constant cost factor.</li> <li>It does <em>not</em> require a perturbation model for input data, which in itself is subject to fine-tuning.</li> </ul> <p>The cons are that:</p> <ul> <li>It does not work so well on low-dimensional output spaces.</li> <li>It seems other methods have better results than this one.</li> <li>It only works for models that output probability distributions (hard to extend to object detection, generation and other tasks).</li> </ul> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <h2 id="Future-Work">Future Work<a class="anchor-link" href="#Future-Work"> </a></h2> </div> </div> </div> <div class="cell border-box-sizing text_cell rendered"><div class="inner_cell"> <div class="text_cell_render border-box-sizing rendered_html"> <p>There are a lot more experiments necessary to verify that the hypotheses in favor of this approach hold. We should also try to find ways to apply it to other tasks.
How to improve out-of-distribution detection further is also an open question.</p> </div> </div> </div> </div>Model based encodings (3)2019-08-06T00:00:00+02:002019-08-06T00:00:00+02:00http://localhost:4000/narsil.github.io/ml/nlp/2019/08/06/model-based-bpe-encodings-3<p>In the <a href="/narsil.github.io/ml/nlp/2019/05/16/model-based-bpe-encodings.html">first segment</a> we looked into how we could build a BPE-based encoding based not only on frequency in the dataset, but directly on the model’s probability measure of the next token. In that article I mentioned that dynamic BPEs are costly because they stop being a one-time operation and have to be redone for every batch, since the vocabulary might have changed. In this article I try to remove the “static” BPE approach entirely and replace it with ML blocks.</p> <blockquote> <h1 id="tldr-in-this-article-we-present-an-idea-to-replace-classical-bpe-algorithm-with-a-pure-ml-version-of-it">TL;DR In this article we present an idea to replace the classical BPE algorithm with a pure ML version of it.</h1> </blockquote> <h2 id="what-is-the-goal-">What is the goal?</h2> <p>The goal is to replace the BPE algorithm, i.e. to go from something like</p> <p>“T|h|e| |c|a|t| |a|t|e| |t|h|e| |a|p|p|l|e|.”</p> <p>to something that has fewer elements:</p> <p>“The |ca|t |at|e |the| |app|le|.”</p> <p>In one sentence, BPE fuses bytes to form tokens based on frequency in the full dataset. For a more detailed example, look at <a href="/narsil.github.io/ml/nlp/2019/05/16/model-based-bpe-encodings.html">the previous article</a>. In this example, you can see there is always a split after a space. That’s a limitation of BPE, so our target might actually look different, maybe more like</p> <p>“The cat |at|e |the app|le|.”</p> <p>Here we can notice that “The cat” is a full token and contains 2 actual words.
So the goal is to fuse some starting bytes into N tokens (let’s say ~10k) that hopefully capture regularities in our dataset and are at least correlated to frequency in the original dataset, like BPE was.</p> <p>Another property we need to keep from BPE is that it can encode an arbitrary string of text. It does not matter if it’s not the same language or even if it makes sense: you CAN encode it, and that is a very desirable property. It avoids the <a href="https://medium.com/cisco-emerge/creating-semantic-representations-of-out-of-vocabulary-words-for-common-nlp-tasks-842dbdafba18">out-of-vocabulary</a> problem.</p> <h2 id="approach">Approach</h2> <h3 id="tokenization">Tokenization</h3> <p>So let’s imagine we have a trained transformer like <a href="https://openai.com/blog/better-language-models/">GPT-2</a>, but trained directly on bytes, NOT on tokens like the original transformer. Now we can use the idea that when a model is highly confident, it probably means that what it’s about to predict is “in the same token”. Let’s take an example. Try to predict the next character (as in a single letter) in the following 2 sentences</p> <blockquote> <p>Sentence 1: “Who are yo…”</p> </blockquote> <blockquote> <p>Sentence 2: “I like …”</p> </blockquote> <p>In the first sentence, normally you would vote with very high confidence for “u”, whereas in the second sentence, you lack a lot of context to be exactly sure of what’s coming next. So “you” would be a token, whereas “like …” can’t be a single token; it has to be at least 2, “like “ and “…”.</p> <p>Here is a small gif of actual probabilities of the language model on a small sentence</p> <p><img src="/narsil.github.io/images/models-2-approach.gif" /></p> <p>You can see that on the left of the graph the probabilities drop; those are the tokens that we try to predict but that are missing context (because we have very few characters before them).
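</p> <p>The splitting rule illustrated by the graph can be sketched as follows (hypothetical numbers: probs[i] stands for the model’s probability of the character that actually occurs at position i, and the threshold plays the role of the epsilon mentioned later):</p>

```python
def split_on_probability_drops(text, probs, epsilon=0.5):
    """Start a new token wherever the model's confidence in the
    actually-occurring next character falls below epsilon."""
    tokens, current = [], text[0]
    for char, p in zip(text[1:], probs[1:]):
        if p < epsilon:
            # Confidence dropped: close the current token, start a new one
            tokens.append(current)
            current = char
        else:
            current += char
    tokens.append(current)
    return tokens

text = "who are you"
# Made-up confidences: low at the spaces, high inside words
probs = [0.1, 0.9, 0.8, 0.2, 0.9, 0.9, 0.9, 0.3, 0.9, 0.95, 0.99]
print(split_on_probability_drops(text, probs))  # → ['who', ' are', ' you']
```

<p>In the experiments below, this threshold corresponds to the epsilon=0.0015 applied to the actual model probabilities.</p> <p>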
For the right side, you can see that the drops in probability are pretty consistent and most often correspond to word boundaries.</p> <h3 id="handling-unknown-tokens">Handling unknown tokens</h3> <p>Now we know how we are going to “fuse” characters, but we are not done yet. BPE tokens are a discrete SET of identified values from 0 to N (~10k in this experiment). Also, BPE can encode an arbitrary new string by using its fusion table. So we can’t just run our algorithm on some specific dataset, count all the tokens created and declare that these are the N tokens for eternity. Let’s imagine I feed my algorithm a new sentence, in a different language, French for instance.</p> <p>“J’adore l’Italie.”</p> <p>We can run our “tokenizer” on this, and receive something like this</p> <p>“J|’|ado|re |l’|Ita|lie.”</p> <p>Now “ado” might not be in our original list, so what do we do with it? Do we declare the token wrong and split it? That would be odd.</p> <p>A key insight is to remember that the first thing that happens to a discrete “token” once it enters the model (all models do this, it’s really not specific to transformers or GPT-2) is that it gets embedded, meaning we go from a number between 1 and N to a vector in <em>d</em>-dimensional space (<em>d</em> is between 100 and 1000 generally).</p> <p>For instance token 3 gets mapped to [0.3, -0.15, 1.4, …] while token 4 gets mapped to [-2.4, -0.014, 0.45, …]</p> <p>So the idea is to directly generate a token embedding (a vector in <em>d</em>-dimensional space), not necessarily a discrete value (a number between 0 and the vocabulary size).</p> <p>In order to do that, all tokens should now be represented in the same way, as a <em>d</em>-dimensional vector.
One way to achieve that is to use an autoencoder.</p> <p><img src="https://upload.wikimedia.org/wikipedia/commons/2/28/Autoencoder_structure.png" alt="" /> or with code</p> <p>The core idea is that when we encounter a new unseen token like “ado” it will still have a representation through the VAE, and will probably be close to a known token like “add”. This can help the network overcome odd tokenization or spelling errors.</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">## The name is VAE but I didn't use the internal KL loss in the end as it prevented/slowed down the learning. </span><span class="k">class</span> <span class="nc">VAE</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="nb">super</span><span class="p">(</span><span class="n">VAE</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">M</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">CONTEXT_SIZE</span> <span class="o">*</span> <span class="n">config</span><span class="o">.</span><span class="n">EMBEDDING_DIM</span> <span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span> <span class="n">m</span> <span class="o">=</span> <span class="mi">400</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">M</span><span class="p">,</span> 
<span class="n">m</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc21</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">config</span><span class="o">.</span><span class="n">EMBEDDING_DIM</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc22</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">config</span><span class="o">.</span><span class="n">EMBEDDING_DIM</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc3</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">EMBEDDING_DIM</span><span class="p">,</span> <span class="n">m</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc4</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">M</span><span class="p">)</span> <span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="c1"># x is [Batch, Context size, Embedding dim] </span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">M</span><span class="p">)</span> <span class="n">h1</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span 
class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc21</span><span class="p">(</span><span class="n">h1</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc22</span><span class="p">(</span><span class="n">h1</span><span class="p">)</span> <span class="k">def</span> <span class="nf">reparameterize</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span><span class="p">):</span> <span class="n">std</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">logvar</span><span class="p">)</span> <span class="n">eps</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">std</span><span class="p">)</span> <span class="k">return</span> <span class="n">mu</span> <span class="o">+</span> <span class="n">eps</span> <span class="o">*</span> <span class="n">std</span> <span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">z</span><span class="p">):</span> <span class="n">h3</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fc3</span><span class="p">(</span><span class="n">z</span><span class="p">))</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span 
class="n">tanh</span><span class="p">(</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc4</span><span class="p">(</span><span class="n">h3</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">config</span><span class="o">.</span><span class="n">CONTEXT_SIZE</span><span class="p">,</span> <span class="n">config</span><span class="o">.</span><span class="n">EMBEDDING_DIM</span><span class="p">)</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">reparameterize</span><span class="p">(</span><span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span><span class="p">)</span> <span class="k">return</span> <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span><span class="p">,</span> <span class="n">z</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">z</span><span class="p">)</span> </code></pre></div></div> <h3 id="final-network">Final network</h3> <p><img src="/narsil.github.io/images/model-based-2.png" /></p> <h2 id="results">Results</h2> <p>Here is a summary of the values of the tokenization we got.</p> <table> <thead> <tr> <th> </th> <th>Raw</th> <th>BPE</th> <th>Model based</th> </tr> </thead> <tbody> <tr> <td>Vocabulary 
size</td> <td>256</td> <td>10000</td> <td>26262</td> </tr> <tr> <td>#Tokens</td> <td>387k</td> <td>90k</td> <td>92k</td> </tr> <tr> <td>Avg token length</td> <td>1</td> <td>3.3</td> <td>6.65</td> </tr> </tbody> </table> <p>Here is an excerpt of the kind of tokenization we created</p> <pre><i>|He w|as on|e of| the |most |n|oticea|ble member|s of the| Reform| Club|, |th|ough| he| s|eemed |always |to |avoid |att|racting at|tention|; an en|ig|mat|i|cal |p|erson|age|,| |ab|out whom l|ittle| was |known|, |e|xc|ept that |he| w|as |a |poli|shed m|an| o|f |th|e |wo|rld|. |Pe|ople sa|id| that h|e |re|sembl|ed| |Byron|--at least| t|hat |his hea|d w|as |Byronic|; |but| he was |a |b|earde|d, tranquil| Byron|, who| |might live| on a |thousand year|s |w|ithout g|r|owing o|ld|.| |Certainly| an| English|man|, it |was |m|ore |doubt|ful w|h|ether |Phileas Fogg| w|as |a |London|er|.</i></pre> <p><a href="/txt/80day_tokenized_exp2.txt">Full text</a></p> <p>This version has been done with epsilon=0.0015.</p> <p>As you can see, “Phileas Fogg” is already a token in this situation, which is a multi-word token not achievable by regular BPE. You can also see that a lot of words consist only of single-byte tokens, which is why this method compresses LESS than regular BPE at the same vocabulary size. Another note is that a common word like “was” is already a token (in the last sentence), but it’s not always the case; this token is context-dependent now!</p> <h2 id="vae">VAE</h2> <p>After the VAE step, the reconstruction is not perfect, yet perfectly legible.</p> <pre><i>|He w|as on|e of| the |most |n|oticea|ihe member|s of the| reform| Club|, |th|ough| he| s|eemed |always |to |asoid |att|nacting at|tention|, an en|ig|mat|i|cal |p|erson|age|,| |ab| it whom l|ittle| was | nown|, |e|xc| pt that |he| w|as |a |poli|shed m|an| o|f |th|e |wo|rld|.
|Pe|ople sa|id| that h|e |re|sembl|ed| |pyron| cat least| t|hat |has hea|d w|as |blronic|; |but| he was |a |b|earde|in tranquil| pyron| who| |eight live| on a |dar and year|s |w|ithout g|r|owing o|ld|.| |rertainly| an| English|man|, it |was |m|ore |doubt|ful w|h|ether |Phileas Fogg| w|as |a |London|er|.</i></pre> <p><a href="/txt/80day_reconstructed2.txt">Full text</a></p> <p>Most of the errors tend to lie in the first characters of <em>long tokens</em>. That’s because I’m forced to pad the input of the VAE and to mask that padding. In practice that means that the first characters of long tokens get updated less than the others, so necessarily they contain more errors. <a href="#notes">More information</a>.</p> <h2 id="upper-level">Upper level</h2> <p>In order to complete the experiment, we need to check that the original language model done directly at the BPE level can be done with this new model-based BPE encoding.</p> <p>It’s pretty slow to train that upper level because we need to flow the gradients all the way through the VAE decoder, and the lower layer decoding step, in order to get the <strong>character level loss</strong> (softmax + nll_loss) to properly train something. That’s a limit of the current approach.</p> <p>If we randomly split the text into train&amp;validation, we can learn almost perfectly (97% top-1 character level accuracy) the language model on top of that model-based BPE.</p> <p><img src="/narsil.github.io/images/models-2-overfit.png" /></p> <p>However this can be considered <strong>overfitting</strong> because even though a specific input was never seen in the valid set, a very close one <em>was</em>.</p> <p>If instead we compare with a fixed split, where the last part of the book is considered the valid set, then we get a much lower result.</p> <p>We could achieve 25% exact character matching, and ~77% top-10 character matching on the valid set, which is the end of the book! The same results happen with BPE, even worse:
we can’t get past 13% top-1 and 25% top-10 on the regular BPE. That’s understandable because the dataset is very small and the last part of the book is different, so it’s very hard to infer it from just the beginning and no other text.</p> <p>Another note is that model-based BPE does not tokenize deterministically; there is some variance to it, depending on the context of a particular word. This actually seems to be a good property (see <a href="https://arxiv.org/abs/1804.10959">this</a>) and might explain away the better performance of model-based BPE over regular BPE. Keep in mind it’s 25% of the <strong>characters</strong> that are correct. If we looked at a discrete view of <strong>tokens</strong> we probably would have a much higher prediction rate (it’s left for future work for now).</p> <p>Here is a picture of the tensorboard values: P_1 is the probability that the predicted character is the correct one, P_10 that it is in the top-10 values.</p> <p><img src="/narsil.github.io/images/models-2-upper.png" /></p> <p>The overfitting starts happening around the ~1M steps mark.</p> <h3 id="notes">Notes</h3> <ul> <li>In the experiment we learned model by model, freezing the lower model before training something on top. That’s because the batching of different layers occurs differently. Learning the whole thing end-to-end is probably going to need some thought. The batching is easy for the lower level: every batch needs a tensor of shape CONTEXT_SIZE (=64) of [0-255] ints. For the VAE, we need a variable length (depending on the token length) times EMBEDDING_DIM (=128). The upper level needs only tensors of size CONTEXT_SIZE * EMBEDDING_DIM, yet if we want to try end-to-end training, we have <strong>no idea</strong> how many bytes we need to generate 1 correct tensor in the upper layer. We know it’s no more than CONTEXT_SIZE² but it would be prohibitive to use that value.</li> <li>The loss NEEDS to always be the byte-level nll loss.
At first I thought a simple MSE loss in the embedding space could be enough to learn the proper models. It seems not to be the case. I could only achieve meaningful results by always referring to the original strings and calculating the NLL loss. When using this loss, the MSE actually <em>increases</em>. This leads me to think that encoding/decoding + softmax are highly anisotropic operators. Looking at the singular values of the embedding matrix, we can see that the highest one is 7.35 and the lowest one 0.12, so there are 2 orders of magnitude between the two. This anisotropy means that the MSE loss, which considers all dimensions of the embedding equal, is actually counting some irrelevant dimensions way too much. It would be much faster and simpler if we could train directly on MSE (it would enable us to train without running all the decoding steps to generate the loss). So we need to add some spectral loss on the embedding of the lower language model to test that hypothesis.</li> <li>The tokens have variable lengths. In order to fix this, we have to pad all sequences during learning. Because we pad, we have to mask the padding during training for both the VAE and the upper LM. Keeping track of this is pretty tricky, and it means gradients in rarely used places will rarely get updated. So we will almost surely miss some letters in our tokens, either at the front or the end of the token depending on how we pad the tokens.</li> </ul> <h2 id="future-work"><strong>Future work</strong></h2> <ul> <li>Actually testing discretizing the tokens to compare with the regular BPE. In that direction, also comparing with a randomized tokenizer as used in <a href="https://github.com/google/sentencepiece">SentencePiece</a> to make sure the results are actually comparable and are indeed linked to tokenization variance.</li> <li>The masking problem really seems to be a current limit of the model.
Finding a workaround would be really valuable.</li> <li>The fact that the NLL loss is required slows down upper layers. It would be awesome if we could smooth out the encoding/decoding matrix so that an L2 loss works directly for the VAE and the upper layer. It probably goes against regular language-model embeddings, so I’m not sure it’s doable.</li> <li>Making the epsilon-based tokenization directly after the embedding layer. This would help <em>stack</em> those levels, hopefully learning higher and higher representations of text, leading to sentence embeddings and so on.</li> <li>On the same idea, another direction would be to do actual discrete tokenization to allow the models to stack.</li> </ul>nicolasIn the first segment we looked into how we could make a BPE based encoding, not only based on frequency in the dataset, but directly on the model probability measure of the next token. In that article I mention that dynamic BPE are costly because they stop being a one time operation but have to be done for every batch because the vocabulary might have changed. In this article I try to completely remove the “static” BPE approach and replace it completely with ML blocks.Model based encodings (2)2019-06-06T00:00:00+02:002019-06-06T00:00:00+02:00http://localhost:4000/narsil.github.io/ml/nlp/2019/06/06/model-based-bpe-encodings-2<p>In the <a href="/narsil.github.io/ml/nlp/2019/05/16/model-based-bpe-encodings.html">first segment</a> we looked into how we could make a BPE-based encoding based not only on frequency in the dataset, but directly on the model’s probability measure of the next token. In that article I mention that dynamic BPE is costly because it stops being a one-time operation and has to be redone for every batch, because the vocabulary might have changed.
In this article I try to remove the “static” BPE approach entirely and replace it with ML blocks.</p> <blockquote> <h1 id="tldr-in-this-article-we-present-an-idea-to-replace-classical-bpe-algorithm-with-a-pure-ml-version-of-it">TL;DR In this article we present an idea to replace the classical BPE algorithm with a pure ML version of it.</h1> </blockquote> <h2 id="what-is-the-goal-">What is the goal?</h2> <p>So the goal is to replace the BPE algorithm, i.e. to go from something like</p> <p>“T|h|e| |c|a|t| |a|t|e| |t|h|e| |a|p|p|l|e|.”</p> <p>to something that has fewer elements:</p> <p>“The |ca|t |at|e |the| |app|le|.”</p> <p>In one sentence, BPE fuses bytes to form tokens based on frequency in the full dataset. For a more detailed example, look at <a href="/narsil.github.io/ml/nlp/2019/05/16/model-based-bpe-encodings.html">the previous article</a>. In this example, you can see there is always a split after a space. That’s a limitation of BPE, so actually our target might look different, maybe more like</p> <p>“The cat |at|e |the app|le|.”</p> <p>Here we can notice that “The cat” is a full token and contains 2 actual words. So the goal is to fuse some starting bytes into N tokens (let’s say ~10k) that hopefully capture regularities in our dataset and are at least correlated to frequency in the original dataset, like BPE was.</p> <p>Another property we need to keep from BPE is that it can encode an arbitrary string of text. It does not matter if it’s not the same language or even if it makes sense: you CAN encode it, and that is a very desirable property. It avoids the <a href="https://medium.com/cisco-emerge/creating-semantic-representations-of-out-of-vocabulary-words-for-common-nlp-tasks-842dbdafba18">out-of-vocabulary</a> problem.</p> <h2 id="approach">Approach</h2> <h3 id="tokenization">Tokenization</h3> <p>So let’s imagine we have a trained transformer like <a href="https://openai.com/blog/better-language-models/">GPT-2</a>.
It is trained directly on bytes, NOT on tokens like the original transformer. Now we can use the idea that when a model is highly confident, it probably means that what it’s about to predict is “in the same token”. Let’s take an example. Try to predict the next character (as in a single letter) in the following 2 sentences</p> <blockquote> <p>Sentence 1: “Who are yo…”</p> </blockquote> <blockquote> <p>Sentence 2: “I like …”</p> </blockquote> <p>In the first sentence, normally you would vote with very high confidence for “u”, whereas in the second sentence, you lack a lot of context to be exactly sure of what’s coming next. So “you” would be a token, whereas “like …” can’t be a single token; it has to be at least 2, “like “ and “…”.</p> <p>Here is a small gif of actual probabilities of the language model on a small sentence</p> <p><img src="/narsil.github.io/images/models-2-approach.gif" /></p> <p>You can see that on the left of the graph the probabilities drop; those are the tokens that we try to predict but that are missing context (because we have very few characters before them). For the right side, you can see that the drops in probability are pretty consistent and most often correspond to word boundaries.</p> <h3 id="handling-unknown-tokens">Handling unknown tokens</h3> <p>Now we know how we are going to “fuse” characters, but we are not done yet. BPE tokens are a discrete SET of identified values from 0 to N (~10k in this experiment). Also, BPE can encode an arbitrary new string by using its fusion table. So we can’t just run our algorithm on some specific dataset, count all the tokens created and declare that these are the N tokens for eternity. Let’s imagine I feed my algorithm a new sentence, in a different language, French for instance.</p> <p>“J’adore l’Italie.”</p> <p>We can run our “tokenizer” on this, and receive something like this</p> <p>“J|’|ado|re |l’|Ita|lie.”</p> <p>Now “ado” might not be in our original list, so what do we do with it?
Do we declare the token wrong and split it? That would be odd.</p> <p>A key insight is to remember that the first thing that happens to a discrete “token” once it enters the model (all models do this, it’s really not specific to transformers or GPT-2) is that it gets embedded, meaning we go from a number between 1 and N to a vector in <em>d</em>-dimensional space (<em>d</em> is between 100 and 1000 generally).</p> <p>For instance token 3 gets mapped to [0.3, -0.15, 1.4, …] while token 4 gets mapped to [-2.4, -0.014, 0.45, …]</p> <p>So the idea is to directly generate a token embedding (a vector in <em>d</em>-dimensional space), not necessarily a discrete value (a number between 0 and the vocabulary size).</p> <p>In order to do that, all tokens should now be represented in the same way, as a <em>d</em>-dimensional vector. One way to achieve that is to use an autoencoder.</p> <p><img src="https://upload.wikimedia.org/wikipedia/commons/2/28/Autoencoder_structure.png" alt="" /> or with code</p> <p>The core idea is that when we encounter a new unseen token like “ado” it will still have a representation through the VAE, and will probably be close to a known token like “add”. This can help the network overcome odd tokenization or spelling errors.</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">## The name is VAE but I didn't use the internal KL loss in the end as it prevented/slowed down the learning.
</span><span class="k">class</span> <span class="nc">VAE</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="nb">super</span><span class="p">(</span><span class="n">VAE</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">M</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">CONTEXT_SIZE</span> <span class="o">*</span> <span class="n">config</span><span class="o">.</span><span class="n">EMBEDDING_DIM</span> <span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span> <span class="n">m</span> <span class="o">=</span> <span class="mi">400</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">M</span><span class="p">,</span> <span class="n">m</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc21</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">config</span><span class="o">.</span><span class="n">EMBEDDING_DIM</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc22</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">config</span><span class="o">.</span><span 
class="n">EMBEDDING_DIM</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc3</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">EMBEDDING_DIM</span><span class="p">,</span> <span class="n">m</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc4</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">M</span><span class="p">)</span> <span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="c1"># x is [Batch, Context size, Embedding dim] </span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">M</span><span class="p">)</span> <span class="n">h1</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc21</span><span class="p">(</span><span class="n">h1</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc22</span><span class="p">(</span><span class="n">h1</span><span class="p">)</span> <span class="k">def</span> <span class="nf">reparameterize</span><span class="p">(</span><span 
class="bp">self</span><span class="p">,</span> <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span><span class="p">):</span> <span class="n">std</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">logvar</span><span class="p">)</span> <span class="n">eps</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">std</span><span class="p">)</span> <span class="k">return</span> <span class="n">mu</span> <span class="o">+</span> <span class="n">eps</span> <span class="o">*</span> <span class="n">std</span> <span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">z</span><span class="p">):</span> <span class="n">h3</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fc3</span><span class="p">(</span><span class="n">z</span><span class="p">))</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc4</span><span class="p">(</span><span class="n">h3</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">config</span><span class="o">.</span><span class="n">CONTEXT_SIZE</span><span class="p">,</span> <span class="n">config</span><span class="o">.</span><span class="n">EMBEDDING_DIM</span><span class="p">)</span> <span class="p">)</span> <span class="k">def</span> <span 
class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">reparameterize</span><span class="p">(</span><span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span><span class="p">)</span> <span class="k">return</span> <span class="n">mu</span><span class="p">,</span> <span class="n">logvar</span><span class="p">,</span> <span class="n">z</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">z</span><span class="p">)</span> </code></pre></div></div> <h3 id="final-network">Final network</h3> <p><img src="/narsil.github.io/images/model-based-2.png" /></p> <h2 id="results">Results</h2> <p>Here is a summary of the tokenization values we got:</p> <table> <thead> <tr> <th> </th> <th>Raw</th> <th>BPE</th> <th>Model based</th> </tr> </thead> <tbody> <tr> <td>Vocabulary size</td> <td>256</td> <td>10000</td> <td>26262</td> </tr> <tr> <td>#Tokens</td> <td>387k</td> <td>90k</td> <td>92k</td> </tr> <tr> <td>Avg token length</td> <td>1</td> <td>3.3</td> <td>6.65</td> </tr> </tbody> </table> <p>Here is an excerpt of the kind of tokenization we created:</p> <pre><i>|He w|as on|e of| the |most |n|oticea|ble member|s of the| Reform| Club|, |th|ough| he| s|eemed |always |to |avoid |att|racting at|tention|; an en|ig|mat|i|cal |p|erson|age|,| |ab|out whom l|ittle| was |known|, |e|xc|ept that |he| w|as |a |poli|shed m|an| o|f |th|e |wo|rld|.
|Pe|ople sa|id| that h|e |re|sembl|ed| |Byron|--at least| t|hat |his hea|d w|as |Byronic|; |but| he was |a |b|earde|d, tranquil| Byron|, who| |might live| on a |thousand year|s |w|ithout g|r|owing o|ld|.| |Certainly| an| English|man|, it |was |m|ore |doubt|ful w|h|ether |Phileas Fogg| w|as |a |London|er|.</i></pre> <p><a href="/txt/80day_tokenized_exp2.txt">Full text</a></p> <p>This version has been done with epsilon=0.0015.</p> <p>As you can see, “Phileas Fogg” is already a token in this situation, which is a multi-word token not achievable by regular BPE. You can also see that a lot of words contain only single-byte tokens, which is why this method compresses LESS than regular BPE at the same vocabulary size. Another note is that a common word like “was” is already a token (in the last sentence), but not always: the tokenization is now context dependent!</p> <h2 id="vae">VAE</h2> <p>After the VAE step, the reconstruction is not perfect, yet perfectly legible.</p> <pre><i>|He w|as on|e of| the |most |n|oticea|ihe member|s of the| reform| Club|, |th|ough| he| s|eemed |always |to |asoid |att|nacting at|tention|, an en|ig|mat|i|cal |p|erson|age|,| |ab| it whom l|ittle| was | nown|, |e|xc| pt that |he| w|as |a |poli|shed m|an| o|f |th|e |wo|rld|. |Pe|ople sa|id| that h|e |re|sembl|ed| |pyron| cat least| t|hat |has hea|d w|as |blronic|; |but| he was |a |b|earde|in tranquil| pyron| who| |eight live| on a |dar and year|s |w|ithout g|r|owing o|ld|.| |rertainly| an| English|man|, it |was |m|ore |doubt|ful w|h|ether |Phileas Fogg| w|as |a |London|er|.</i></pre> <p><a href="/txt/80day_reconstructed2.txt">Full text</a></p> <p>Most of the errors tend to lie in the first characters of <em>long tokens</em>. That’s because I’m forced to pad the input of the VAE and to mask that padding. In practice that means the first characters of long tokens get updated less than the others, so they necessarily contain more errors.
<a href="#notes">More information</a>.</p> <h2 id="upper-level">Upper level</h2> <p>In order to complete the experiment, we need to check that the original language model, done directly at the BPE level, can be reproduced with this new model-based BPE encoding.</p> <p>It’s pretty slow to train that upper level because we need to flow the gradients all the way through the VAE decoder and the lower-layer decoding step in order to get the <strong>character level loss</strong> (softmax + nll_loss) needed to properly train something. That’s a limit of the current approach.</p> <p>If we randomly split the text into train&amp;validation, we can learn the language model on top of that model-based BPE almost perfectly (97% top-1 character-level accuracy).</p> <p><img src="/narsil.github.io/images/models-2-overfit.png" /></p> <p>However this can be considered <strong>overfitting</strong> because even though a specific input was never seen in the valid set, a very close one <em>was</em>.</p> <p>If instead we compare with a fixed split, where the last part of the book is considered the valid set, then we get much lower results.</p> <p>We could achieve 25% exact character matching and ~77% top-10 character matching on the valid set, which is the end of the book! Regular BPE fares even worse: we can’t get past 13% top-1 and 25% top-10. That’s understandable because the dataset is very small and the last part of the book is different, so it’s very hard to infer it from just the beginning and no other text.</p> <p>Another note is that model-based BPE does not tokenize deterministically; there is some variance, depending on the context of a particular word. This actually seems to be a good property (see <a href="https://arxiv.org/abs/1804.10959">this</a>) and might explain the better performance of model-based BPE over regular BPE. Keep in mind it’s 25% of the <strong>characters</strong> that are correct.
If we looked at a discrete view of <strong>tokens</strong> we probably would have a much higher prediction rate (it’s left as future work for now).</p> <p>Here is a picture of the tensorboard values: P_1 is the probability that the predicted character is the correct one, P_10 that it is in the top-10 values.</p> <p><img src="/narsil.github.io/images/models-2-upper.png" /></p> <p>The overfitting starts happening around the ~1M steps mark.</p> <h3 id="notes">Notes</h3> <ul> <li>In the experiment we learned model by model, freezing the lower model before training something on top. That’s because the batching of the different layers occurs differently. Learning the whole thing end-to-end is probably going to need some thought. The batching is easy for the lower level: every batch needs a tensor of shape CONTEXT_SIZE (=64) of [0-255] ints. For the VAE, we need a variable length (depending on the token length) times EMBEDDING_DIM (=128). The upper level needs only tensors of size CONTEXT_SIZE * EMBEDDING_DIM, yet if we want to try end-to-end training, we have <strong>no idea</strong> how many bytes we need to generate 1 correct tensor in the upper layer. We know it’s no more than CONTEXT_SIZE², but using that value would be prohibitive.</li> <li>The loss NEEDS to always be the byte-level nll loss. At first I thought a simple MSE loss in the embedding space could be enough to learn the proper models. It seems not to be the case. I could only achieve meaningful results by always referring to the original strings and calculating the NLL loss. When using this loss, the MSE actually <em>increases</em>. This leads me to think that encoding/decoding + softmax are highly anisotropic operators. Looking at the singular values of the embedding matrix, we can see that the highest one is 7.35 and the lowest one 0.12, so there are 2 orders of magnitude between the two.
This anisotropy means that the MSE loss, which considers all dimensions of the embedding equal, is actually counting some irrelevant dimensions way too much. It would be much faster and simpler if we could train directly on MSE (it would enable us to train without running all the decoding steps to generate the loss). So we would need to add some spectral loss on the embedding of the lower language model to test that hypothesis.</li> <li>The tokens have variable lengths. In order to fix this, we have to pad all sequences during learning. Because we pad, we have to mask the padding during training for both the VAE and the upper LM. Keeping track of this is pretty tricky, and it means gradients on rarely used places will rarely get updated. So we will almost surely miss some letters in our tokens, either at the front or at the end of the token depending on how we pad.</li> </ul> <h2 id="future-work"><strong>Future work</strong></h2> <ul> <li>Actually testing discretizing the tokens to compare with the regular BPE. In that direction, also comparing with a randomized tokenizer as used in <a href="https://github.com/google/sentencepiece">SentencePiece</a> to make sure the results are actually comparable and are indeed linked to tokenization variance.</li> <li>The masking problem really seems to be a current limit of the model. Finding a workaround would be really valuable.</li> <li>The fact that the NLL loss is required slows down the upper layers. It would be awesome if we could smooth out the encoding/decoding matrix so that training directly on L2 works for the VAE and the upper layer. It probably goes against regular language model embeddings, so I’m not sure it’s doable.</li> <li>Making the epsilon based tokenization directly after the embedding layer.
This would help <em>stack</em> those levels, hopefully learning higher and higher representations of text, leading to sentence embeddings and so on.</li> <li>On the same idea, another direction would be to do actual discrete tokenization to allow for the models to stack.</li> </ul>nicolasIn the first segment we looked into how we could make a BPE based encoding, not only based on frequency in the dataset, but directly on the model probability measure of the next token. In that article I mention that dynamic BPE are costly because they stop being a one time operation but have to be done for every batch because the vocabulary might have changed. In this article I try to completely remove the “static” BPE approach and replace it completely with ML blocks.Model based encodings2019-05-16T00:00:00+02:002019-05-16T00:00:00+02:00http://localhost:4000/narsil.github.io/ml/nlp/2019/05/16/model-based-bpe-encodings<p><a href="https://en.wikipedia.org/wiki/Byte_pair_encoding">Byte-pair encodings</a> (BPE) are now very commonly used in NLP. In <a href="https://openai.com/blog/better-language-models/">GPT-2</a>, Byte-pair encodings are used to preformat the raw texts before feeding the model. But this is a relatively costly step for your preprocessing and has some limitations. For instance, you have to split your data on spaces if you want your byte pair algorithm to compute in reasonable time.</p> <blockquote> <h1 id="tldr-in-this-article-we-present-an-idea-to-generate-byte-pair-encodings-not-based-on-frequency-in-the-dataset-but-on-the-quality-of-the-prediction-of-our-model-this-enables-us-to-predict-multi-word-tokens-like-new-york-and-address-languages-that-dont-use-spaces-to-split-words">TL;DR In this article we present an idea to generate Byte pair encodings, not based on frequency in the dataset, but on the quality of the prediction of our model.
This enables us to predict multi word tokens like “New York” and address languages that don’t use spaces to split words.</h1> </blockquote> <h2 id="what-are-byte-pair-encodings-">What are Byte Pair Encodings?</h2> <p>Byte-pair encodings are a way to compress information from pairs of bytes that will form tokens. Let’s take an example:</p> <p>“I love carrots and I love apples.”</p> <p>This sentence read by a computer is only a sequence of bytes (a byte is simply a number between 0 and 255). That means that to a computer our sentence looks like</p> <p>“I love carrots and I love apples.” -&gt; [73, 32, 108, 111, 118, 101, 32, 99, 97, 114, 114, 111, 116, 115, 32, 97, 110, 100, 32, 73, 32, 108, 111, 118, 101, 32, 97, 112, 112, 108, 101, 115, 46]</p> <p>From that example, you may remark that some bytes occur multiple times together, like [108, 111], which occurs twice (it’s “lo” from “love”). So let’s build a new token for this frequent pair. Numbers from 0 to 255 are already taken, so we’ll take the next available number, which is 256, and we are going to store that information in a table</p> <p>[108, 111] -&gt; 256</p> <p>Now if we use that new token to encode our original bytes, whenever we encounter [108, 111], we’ll replace it by 256, so the original byte string becomes:</p> <p>[73, 32, <strong>256</strong>, 118, 101, 32, 99, 97, 114, 114, 111, 116, 115, 32, 97, 110, 100, 32, 73, 32, <strong>256</strong>, 118, 101, 32, 97, 112, 112, 108, 101, 115, 46]</p> <p>We went from 33 numbers to 31 numbers. We can rinse and repeat to compress even further. Originally, BPE was proposed as a compression algorithm. It’s not the best compression tool, so we won’t look at that side of the algorithm.
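<p>The merge step described above can be sketched in a few lines of Python (a toy illustration only, not a real implementation; note that on this particular sentence several pairs are tied at two occurrences, so the sketch may pick a different winning pair than the [108, 111] of the walkthrough):</p>

```python
from collections import Counter

def merge_most_frequent_pair(ids, new_token):
    """One BPE step: find the most frequent adjacent pair and replace it."""
    pairs = Counter(zip(ids, ids[1:]))
    best = max(pairs, key=pairs.get)
    merged, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == best:
            merged.append(new_token)  # e.g. [108, 111] -> 256
            i += 2
        else:
            merged.append(ids[i])
            i += 1
    return best, merged

ids = list("I love carrots and I love apples.".encode("utf-8"))
best, merged = merge_most_frequent_pair(ids, 256)  # 33 numbers -> 31 numbers
```

<p>Rinsing and repeating, feeding <code>merged</code> back in with an incremented <code>new_token</code>, builds the whole byte-pair table.</p>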
Now you get what we are looking at when we train a model on BPEs: just a list of numbers.</p> <p>Typically a BPE vocabulary contains ~10k tokens (GPT-2 has 50k), which means it can capture very frequent words like “the” entirely, and parts of words that come in many variations, like “ment” (<strong>ment</strong>ally, environ<strong>ment</strong> …). What’s great about it is that words can now share semantic parts in their model representation: environ-ment, environ-ment-al and environ-ment-ally will all share “environ”, which will carry most of the semantic meaning, while the rest will hopefully carry grammatical information.</p> <p>The real advantage of BPE over classical Word Embeddings is that it does not fall into the out-of-vocabulary error (when a word was never seen). At worst you can always fall back to single bytes.</p> <h2 id="whats-the-problem-with-bpe-"><strong>What’s the problem with BPE?</strong></h2> <p>The BPE algorithm is pretty bad in terms of complexity to calculate (roughly O(n²); you can look at a very good implementation, <a href="https://github.com/glample/fastBPE">https://github.com/glample/fastBPE</a>). BPE is also pretty bad when you want to encode some new text. A greedy algorithm will be O(n) but will not produce the best encoding possible; the best encoding is actually O(n²) in the general case.</p> <p>To be honest, most implementations split on spaces as mentioned earlier, which speeds up the algorithm quite a bit. Once we have encoded a full word like “the” there is no way to add tokens to it, so it’s not necessary to look at it anymore for potential byte pairs, so we can assume the encoding&amp;table creation go from O(n²) to something much closer to O(n). In addition, at encoding time, once we know the encoding for “the” we can cache that information, leading to further speed ups.
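<p>The split-on-spaces and caching tricks can be sketched as follows (a toy greedy encoder under assumed names, with a made-up three-entry vocabulary; <code>lru_cache</code> plays the role of the per-word cache):</p>

```python
from functools import lru_cache

# Toy vocabulary (made up for illustration): token string -> id.
# Ids 0-255 are reserved for raw bytes.
VOCAB = {"the ": 256, "th": 257, "e ": 258}

@lru_cache(maxsize=None)
def encode_word(word):
    """Greedily encode one space-delimited chunk, longest match first.

    Thanks to the cache, a frequent word like "the" is encoded only once."""
    ids, i = [], 0
    while i < len(word):
        for j in range(len(word), i, -1):  # try the longest substring first
            if word[i:j] in VOCAB:
                ids.append(VOCAB[word[i:j]])
                i = j
                break
        else:  # no vocabulary entry starts here: fall back to a single byte
            ids.extend(word[i].encode("utf-8"))
            i += 1
    return tuple(ids)

def encode(text):
    # Splitting on spaces is exactly the shortcut discussed above: pairs never
    # cross a word boundary, so each chunk can be encoded independently.
    return [t for w in text.split(" ") for t in encode_word(w + " ")]
```

<p>Greedy matching is O(n) per chunk but, as noted above, not necessarily optimal; also, this sketch naively appends the separating space to every chunk, including the last one, which a real implementation would handle more carefully.</p>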
But using spaces as a special character has drawbacks, namely:</p> <ul> <li> <p>We can’t address as well languages that don’t use a space to separate words, like Chinese (arguably German).</p> </li> <li> <p>We can’t encode frequently occurring multi words like “New York” or “European Union” or “black holes”.</p> </li> </ul> <p>The second problem is especially bad when you consider examples where the semantics are very different from those of the composing words: “Chicago Bulls” has nothing to do with bulls.</p> <h2 id="ε-bpe-or-model-based-bpe-encoding"><strong>ε-BPE or model based BPE encoding</strong></h2> <p>The core idea is that instead of using frequency in the dataset to create the byte pairs, we can use the transition probabilities of the model to create the BPE. Let’s use some kind of transformer, GPT-2 for instance. The core idea of that model is to predict the next token (in the BPE sense) given a fixed context size. But we can use the output probability of the model in order to create new tokens, not because they are frequent but because they are easy to predict. For instance, in a book that contains a character “Sir Francis” who appears rarely, but who is the only character named “Sir …”, the algorithm might learn quite easily that “Sir “ is followed by “Francis” with great confidence, even if the occurrence of those words is pretty low compared to common words like “the”, “like” and “I”.</p> <p>So the core algorithm will train a simple transformer on a dataset of regular bytes (at least at the start). Then, as the algorithm learns, some predictions will be above 1-ε. We can keep track of those, and keep track of the last token we received, to check if we were correct.</p> <p>Let’s keep a hit map to see how successful our algorithm is. For instance, I predicted “Fo” will be followed by “gg” (Phileas Fogg is a character in Around the world in 80 days) with probability &gt; 1-ε.
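<p>The bookkeeping behind this idea can be sketched as follows (a toy sketch with hypothetical names, using the k/n - 1/sqrt(n) &gt; 1-ε fluctuation-interval test on the recorded success counts):</p>

```python
import math
from collections import defaultdict

EPSILON = 0.4  # probability margin; 40% is the value used in the experiments

# stats[(previous_token, predicted_token)] = [successes, trials]
stats = defaultdict(lambda: [0, 0])

def record_prediction(prev_token, predicted, actual, confidence):
    """Track confident predictions and return a pair once it is promotable.

    A pair becomes a new token when the lower bound of its fluctuation
    interval clears the margin: k/n - 1/sqrt(n) > 1 - EPSILON."""
    if confidence <= 1 - EPSILON:
        return None  # the model was not confident enough: ignore
    k, n = stats[(prev_token, predicted)]
    k, n = k + int(predicted == actual), n + 1
    stats[(prev_token, predicted)] = [k, n]
    if k / n - 1 / math.sqrt(n) > 1 - EPSILON:
        return (prev_token, predicted)  # caller adds the pair to the vocabulary
    return None
```

<p>With ε = 0.4, a pair the model always predicts correctly gets promoted on its 7th confident hit (1 - 1/sqrt(7) ≈ 0.622 &gt; 0.6), and 14 successes out of 15 also clear the bar (14/15 - 1/sqrt(15) ≈ 0.675 &gt; 0.6).</p>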
I was correct in 14 cases, and got it wrong in 1 case (let’s say it was a classical “Fo” “g “). We were correct 14/15 times; that’s 93% accuracy. If we look at the fluctuation interval associated with that, we get a [92.74-93.25%] range. If 92.74% &gt; 1-ε we can conclude that our transition prediction is really very good, not a fluke of the model.</p> <p>More generally, if we want 95% confidence when we upgrade this transition, we need to respect the following inequality: k/n - 1/sqrt(n) &gt; 1-ε, where k is the number of successful predictions, n is the total number of predictions and ε the probability margin explained earlier.</p> <p>This model is slightly different from byte pair encoding, but now we don’t suffer from the 2 problems mentioned above: we can get pretty long tokens if the dataset allows for it, and we can use Chinese or German as the space character does not play any special role.</p> <h2 id="results"><strong>Results</strong></h2> <p>Implementation can be found here. On the first run, we ran on the book <a href="https://en.wikipedia.org/wiki/Around_the_World_in_Eighty_Days">Around the world in 80 days</a> by Jules Verne. It’s a very small dataset, but the idea is to check that we can actually overcome BPE’s limitations. Here are a few telling tokens that were created while running on the dataset:</p> <table> <thead> <tr> <th>Promotion #</th> <th>Token created</th> </tr> </thead> <tbody> <tr> <td>338</td> <td>“Mr. Fogg”</td> </tr> <tr> <td>357</td> <td>“Phileas Fogg”</td> </tr> <tr> <td>360</td> <td>“Passepartout”</td> </tr> <tr> <td>635</td> <td>“ir Franc” (Sir Francis)</td> </tr> <tr> <td>781</td> <td>“It was”</td> </tr> <tr> <td>900</td> <td>’” asked’ (contains a quote character)</td> </tr> </tbody> </table> <p>What is interesting is that:</p> <ul> <li> <p>We managed to create multi word tokens like “Phileas Fogg”</p> </li> <li> <p>Multi word tokens are a minority in terms of tokens created by the algorithm.
Out of 421 tokens that contain a space character, only 27 are multi word tokens like “New York”. The remaining 394 tokens contain an ending space, meaning our algorithm is learning word boundaries. This is reassuring because traditional BPE implementations usually hardcode that information.</p> </li> <li> <p>Multi word tokens are names of characters in the book, which occur frequently; they are entities by themselves (Fogg even has 2 tokens associated with him)</p> </li> <li> <p>2 multi word tokens are <strong>not</strong> specific to the book: “it was” is a pretty common 2 word token in English descriptions, and “(…) asked” is a very common continuation when we start a quote and end a sentence with a question mark. We can guess that “(…) said” would be a token further down the line, but it’s harder as there is probably a wider variety of verbs that can fit (said, replied, answered and so on…)</p> </li> </ul> <p>Here is a more complete comparison of standard BPE with ε-BPE on the first 100 tokens generated. As you can see, more tokens are dedicated to syntax in eBPE, which standard BPE gladly ignores by splitting on newlines and spaces.</p> <table> <thead> <tr> <th>Standard BPE</th> <th>eBPE</th> </tr> </thead> <tbody> <tr> <td>‘th’</td> <td>‘\r\n’</td> </tr> <tr> <td>‘the ‘</td> <td>’, ‘</td> </tr> <tr> <td>‘an’</td> <td>‘d ‘</td> </tr> <tr> <td>‘in’</td> <td>‘Th’</td> </tr> <tr> <td>‘ou’</td> <td>‘ve’</td> </tr> <tr> <td>‘er’</td> <td>‘y ‘</td> </tr> <tr> <td>‘ed ‘</td> <td>’; ‘</td> </tr> <tr> <td>‘ar’</td> <td>‘f ‘</td> </tr> <tr> <td>‘hi’</td> <td>’,\r\n’</td> </tr> <tr> <td>‘on’</td> <td>‘\r\n\r\n’</td> </tr> <tr> <td>‘re’</td> <td>‘th’</td> </tr> <tr> <td>‘en’</td> <td>‘qu’</td> </tr> <tr> <td>‘and ‘</td> <td>‘the’</td> </tr> <tr> <td>‘of ‘</td> <td>’ ‘</td> </tr> <tr> <td>‘st’</td> <td>‘the ‘</td> </tr> <tr> <td>‘to ‘</td> <td>‘The’</td> </tr> <tr> <td>‘as ‘</td> <td>‘\r\n’</td> </tr> <tr> <td>‘se’</td> <td>’, ‘</td> </tr> <tr> <td>‘ha’</td> <td>‘y ‘</td> </tr>
<tr> <td>‘or’</td> <td>‘d ‘</td> </tr> <tr> <td>’.\r ‘</td> <td>‘Th’</td> </tr> <tr> <td>‘it’</td> <td>‘ve’</td> </tr> <tr> <td>‘he ‘</td> <td>’; ‘</td> </tr> <tr> <td>‘le’</td> <td>‘f ‘</td> </tr> <tr> <td>‘ing ‘</td> <td>’,\r\n’</td> </tr> <tr> <td>’,\r ‘</td> <td>’ ‘</td> </tr> <tr> <td>‘as’</td> <td>‘\r\n’</td> </tr> <tr> <td>‘in ‘</td> <td>’, ‘</td> </tr> <tr> <td>‘at’</td> <td>‘d ‘</td> </tr> <tr> <td>‘at ‘</td> <td>‘y ‘</td> </tr> <tr> <td>‘ro’</td> <td>‘Th’</td> </tr> <tr> <td>‘er ‘</td> <td>‘ve’</td> </tr> <tr> <td>‘al’</td> <td>‘f ‘</td> </tr> <tr> <td>‘es’</td> <td>’; ‘</td> </tr> <tr> <td>‘on ‘</td> <td>’ ‘</td> </tr> <tr> <td>‘was ‘</td> <td>’,\r\n’</td> </tr> <tr> <td>‘no’</td> <td>‘th’</td> </tr> <tr> <td>‘his ‘</td> <td>‘\r\n’</td> </tr> <tr> <td>‘ed’</td> <td>’, ‘</td> </tr> <tr> <td>‘ac’</td> <td>‘d ‘</td> </tr> <tr> <td>’“\r ‘</td> <td>‘y ‘</td> </tr> <tr> <td>‘ri’</td> <td>‘Th’</td> </tr> <tr> <td>‘be’</td> <td>‘ve’</td> </tr> <tr> <td>‘ly ‘</td> <td>‘f ‘</td> </tr> <tr> <td>‘om’</td> <td>’; ‘</td> </tr> <tr> <td>‘li’</td> <td>’ ‘</td> </tr> <tr> <td>‘en ‘</td> <td>’,\r\n’</td> </tr> <tr> <td>‘ti’</td> <td>‘th’</td> </tr> <tr> <td>‘og’</td> <td>‘\r\n\r\n’</td> </tr> <tr> <td>‘ra’</td> <td>‘the’</td> </tr> <tr> <td>‘di’</td> <td>‘the ‘</td> </tr> <tr> <td>‘art’</td> <td>‘The’</td> </tr> <tr> <td>‘Fog’</td> <td>‘qu’</td> </tr> <tr> <td>‘the’</td> <td>’s ‘</td> </tr> <tr> <td>‘ma’</td> <td>‘The ‘</td> </tr> <tr> <td>‘ve ‘</td> <td>‘g ‘</td> </tr> <tr> <td>‘is ‘</td> <td>’,”’</td> </tr> <tr> <td>‘or ‘</td> <td>‘no’</td> </tr> <tr> <td>‘ld ‘</td> <td>‘t ‘</td> </tr> <tr> <td>‘whi’</td> <td>‘th ‘</td> </tr> <tr> <td>‘il’</td> <td>‘o ‘</td> </tr> <tr> <td>‘ur’</td> <td>’?”’</td> </tr> <tr> <td>’s, ‘</td> <td>‘\r\n\r\n”’</td> </tr> <tr> <td>‘de’</td> <td>’,” ‘</td> </tr> <tr> <td>‘wh’</td> <td>‘Mr’</td> </tr> <tr> <td>‘lo’</td> <td>‘e ‘</td> </tr> <tr> <td>‘ch ‘</td> <td>‘yo’</td> </tr> <tr> <td>‘ere ‘</td> <td>‘Yo’</td> </tr> <tr> <td>‘ith ‘</td> 
<td>‘ou’</td> </tr> <tr> <td>‘The ‘</td> <td>’. ‘</td> </tr> <tr> <td>‘am’</td> <td>‘nd ‘</td> </tr> <tr> <td>‘ent’</td> <td>‘h ‘</td> </tr> <tr> <td>‘un’</td> <td>‘n ‘</td> </tr> <tr> <td>‘gh’</td> <td>’;\r\n’</td> </tr> <tr> <td>‘with ‘</td> <td>‘og’</td> </tr> <tr> <td>‘an ‘</td> <td>‘you’</td> </tr> <tr> <td>‘oun’</td> <td>‘r ‘</td> </tr> <tr> <td>‘part’</td> <td>‘of ‘</td> </tr> <tr> <td>‘ver’</td> <td>‘to ‘</td> </tr> <tr> <td>‘si’</td> <td>’s F’</td> </tr> <tr> <td>‘had ‘</td> <td>‘Pa’</td> </tr> <tr> <td>‘not ‘</td> <td>‘as ‘</td> </tr> <tr> <td>‘ould ‘</td> <td>'’s ‘</td> </tr> <tr> <td>‘ing’</td> <td>’. F’</td> </tr> <tr> <td>‘out ‘</td> <td>‘is ‘</td> </tr> <tr> <td>‘el’</td> <td>‘ld ‘</td> </tr> <tr> <td>‘sa’</td> <td>‘ng ‘</td> </tr> <tr> <td>‘ce’</td> <td>‘at ‘</td> </tr> <tr> <td>‘that ‘</td> <td>‘re’</td> </tr> <tr> <td>‘asse’</td> <td>‘ve ‘</td> </tr> <tr> <td>‘fi’</td> <td>‘gh’</td> </tr> <tr> <td>‘ol’</td> <td>‘ut ‘</td> </tr> <tr> <td>‘sh’</td> <td>‘ll’</td> </tr> <tr> <td>‘r. ‘</td> <td>‘Pas’</td> </tr> <tr> <td>’.”\r ‘</td> <td>‘re ‘</td> </tr> <tr> <td>‘Passe’</td> <td>‘ed ‘</td> </tr> <tr> <td>‘Passepart’</td> <td>’. Fog’</td> </tr> <tr> <td>‘ut ‘</td> <td>‘ch ‘</td> </tr> <tr> <td>‘which ‘</td> <td>‘and ‘</td> </tr> <tr> <td>‘ay’</td> <td>‘ea’</td> </tr> </tbody> </table> <p>I would love to check the tokenization of German or Chinese, but I’m not a speaker of either language, so it’s hard for me to analyze the results anyway. What’s for sure is that the technique is applicable.</p> <p>I also tried the technique on different types of files, like wav files, mp3 files, even jpeg images. Analysis is harder to do. Still, some interesting notes: it took longer for the model to emit new tokens on the mp3 files than on the wav files. The mp3 file is compressed, therefore has a higher entropy (meaning it’s harder to predict the next token) than the wav files, so the model takes longer to actually get good at predicting.
It’s probable (I haven’t checked) that we have to overfit the mp3 file and jpeg files before we can predict any meaningful content (except maybe the header part).</p> <h2 id="future-work"><strong>Future Work</strong></h2> <p>Many interesting ideas are left to explore to continue the idea of models creating their own tokenization. For now a limiting factor is the actual BPE encoding process, which takes longer and longer as the model creates new tokens. That’s because the encoding process is done in Python, so it’s quite slow and can’t be precalculated as you would do with fixed BPE encodings. To give a sense of the slowdown, the training loop starts at ~11it/s on a GTX970 and finishes at roughly 10s/it. That’s a 100x slowdown over the course of the training, with only 1k tokens in the end, far from the 50k used by GPT-2 for instance.</p> <p>It’s going to be an actual requirement to train on larger and more representative datasets. Training on bigger datasets would help us understand how important those multi word tokens are, and maybe what those multi words are. The token “(…) <strong>asked</strong>” was pretty surprising to me; I’m eager to see what else can be discovered.</p> <p>The actual epsilon used was 40%, which is actually quite big (the value was chosen by trial and error, to get a small but non-null rejection rate of new tokens: adding tokens as fast as possible without making too many mistakes). That value probably has a sweet spot depending on the number of current tokens; after speeding up the process, it would be interesting to look at the best value for epsilon as a function of the number of tokens.</p>nicolas