/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

// https://github.com/tensorflow/tfjs-models/blob/master/universal-sentence-encoder/src/tokenizer/trie.ts
class TrieNode {
  constructor(key) {
    this.key = key;
    this.parent = null;
    this.children = {};
    this.end = false;
  }

  // Walks parent pointers back to the root to rebuild the full token,
  // returning [characters, score, vocabulary index].
  getWord() {
    const output = [];
    let node = this;

    while (node !== null) {
      if (node.key !== null) {
        output.unshift(node.key);
      }
      node = node.parent;
    }

    return [output, this.score, this.index];
  }
}

class Trie {
  constructor() {
    this.root = new TrieNode(null);
  }

  insert(word, score, index) {
    let node = this.root;

    const symbols = [];
    for (const symbol of word) {
      symbols.push(symbol);
    }

    for (let i = 0; i < symbols.length; i++) {
      if (!node.children[symbols[i]]) {
        node.children[symbols[i]] = new TrieNode(symbols[i]);
        node.children[symbols[i]].parent = node;
      }

      node = node.children[symbols[i]];
      if (i === symbols.length - 1) {
        node.end = true;
        node.score = score;
        node.index = index;
      }
    }
  }

  // Follows the characters of ss down the trie; returns the last node reached
  // (or undefined if the path falls off the trie).
  find(ss) {
    let node = this.root;
    let iter = 0;

    while (iter < ss.length && node != null) {
      node = node.children[ss[iter]];
      iter++;
    }

    return node;
  }
}

const bert = {
  loadTokenizer: async () => {
    const tokenizer = new BertTokenizer();
    await tokenizer.load();
    return tokenizer;
  }
};

class BertTokenizer {
  constructor() {
    this.separator = '\u2581';
    this.UNK_INDEX = 100;
  }

  async load() {
    this.vocab = await this.loadVocab();

    this.trie = new Trie();
    // Actual tokens start at 999.
    for (let i = 999; i < this.vocab.length; i++) {
      const word = this.vocab[i];
      this.trie.insert(word, 1, i);
    }

    this.token2Id = {};
    this.vocab.forEach((d, i) => {
      this.token2Id[d] = i;
    });

    this.decode = a => a.map(d => this.vocab[d].replace('▁', ' ')).join('');

    // Adds [CLS] and [SEP]
    this.tokenizeCLS = str => [101, ...this.tokenize(str), 102];
  }

  async loadVocab() {
    if (!window.bertProcessedVocab) {
      window.bertProcessedVocab =
        await (await fetch('data/processed_vocab.json')).json();
    }
    return window.bertProcessedVocab;
  }

  processInput(text) {
    const words = text.split(' ');
    return words.map(word => {
      if (word !== '[CLS]' && word !== '[SEP]') {
        return this.separator + word.toLowerCase().normalize('NFKC');
      }
      return word;
    });
  }

  tokenize(text) {
    // Source:
    // https://github.com/google-research/bert/blob/88a817c37f788702a363ff935fd173b6dc6ac0d6/tokenization.py#L311
    let outputTokens = [];

    const words = this.processInput(text);

    for (let i = 0; i < words.length; i++) {
      const chars = [];
      for (const symbol of words[i]) {
        chars.push(symbol);
      }

      let isUnknown = false;
      let start = 0;
      const subTokens = [];

      const charsLength = chars.length;

      // Greedy longest-match-first: repeatedly take the longest prefix of the
      // remaining characters that appears in the vocabulary trie.
      while (start < charsLength) {
        let end = charsLength;
        let currIndex;

        while (start < end) {
          const substr = chars.slice(start, end).join('');

          const match = this.trie.find(substr);
          if (match != null && match.end) {
            currIndex = match.getWord()[2];
            break;
          }

          end = end - 1;
        }

        if (currIndex == null) {
          isUnknown = true;
          break;
        }

        subTokens.push(currIndex);
        start = end;
      }

      if (isUnknown) {
        outputTokens.push(this.UNK_INDEX);
      } else {
        outputTokens = outputTokens.concat(subTokens);
      }
    }

    return outputTokens;
  }
}
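
// Usage sketch (not part of the original file; the demoTokenize name is only
// for illustration). It assumes this script runs in a browser page that also
// serves data/processed_vocab.json, since loadVocab() fetches that path and
// caches it on window.bertProcessedVocab.
async function demoTokenize() {
  const tokenizer = await bert.loadTokenizer();

  // tokenizeCLS wraps the WordPiece ids with [CLS] (101) and [SEP] (102).
  const ids = tokenizer.tokenizeCLS('the dog chased the cat');

  // decode maps ids back to vocab strings, turning '\u2581' back into spaces;
  // slicing off the first and last id skips the [CLS]/[SEP] markers.
  console.log(ids, tokenizer.decode(ids.slice(1, -1)));
}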